mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-09 15:25:17 +00:00
Compare commits
709 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9fccc86b71 | ||
|
|
74683560a7 | ||
|
|
1e4f9dd438 | ||
|
|
b9ff916494 | ||
|
|
9bf4a0cad2 | ||
|
|
c32e2a8196 | ||
|
|
873d41582f | ||
|
|
6fb7d85558 | ||
|
|
d5e3e32d58 | ||
|
|
f353a54555 | ||
|
|
1d6e2e751d | ||
|
|
cc50b63422 | ||
|
|
15ae83a15b | ||
|
|
81b369aed9 | ||
|
|
ecc850bfb7 | ||
|
|
19b4ef33e0 | ||
|
|
7ca045d8b9 | ||
|
|
25b9df478c | ||
|
|
abfca6aab2 | ||
|
|
3c71c075db | ||
|
|
9c2992bfb2 | ||
|
|
269a1c5452 | ||
|
|
22ce65ac72 | ||
|
|
a2f8f59192 | ||
|
|
51611c25d7 | ||
|
|
eb1bbaa63b | ||
|
|
30a59168d7 | ||
|
|
4c8026ac3d | ||
|
|
8aeb4b7d54 | ||
|
|
b2172cb047 | ||
|
|
c8884f5e25 | ||
|
|
d9c6317c84 | ||
|
|
d29ec95526 | ||
|
|
ef4508dbc8 | ||
|
|
f775e46fe2 | ||
|
|
65ad5c0c9d | ||
|
|
88bf4e77ec | ||
|
|
194f66ca9c | ||
|
|
a4f8015caa | ||
|
|
ffd129909e | ||
|
|
9332316383 | ||
|
|
6dcbbf64c3 | ||
|
|
c9aa1ff99d | ||
|
|
2ce3553612 | ||
|
|
2e14f787d4 | ||
|
|
523b41ccd2 | ||
|
|
09970dc7af | ||
|
|
d81abd401c | ||
|
|
a6cba25bc1 | ||
|
|
c6fa1d0e67 | ||
|
|
ac56e1e88b | ||
|
|
a9ee971e1c | ||
|
|
73cef3a25a | ||
|
|
9b72ea9efa | ||
|
|
9f364441e8 | ||
|
|
e49a1c07bf | ||
|
|
5364a2471d | ||
|
|
fef4fdb0eb | ||
|
|
c2bf600a39 | ||
|
|
8d9f4edf9b | ||
|
|
020e61d0da | ||
|
|
6184c43319 | ||
|
|
2cbe4a790c | ||
|
|
68b3565d7b | ||
|
|
3f385a8572 | ||
|
|
9823dc35e1 | ||
|
|
059bfee91b | ||
|
|
7beaf0eaa2 | ||
|
|
1fef90ff58 | ||
|
|
8447fd27a0 | ||
|
|
7831cba9f6 | ||
|
|
e02b2d58d5 | ||
|
|
28726632a9 | ||
|
|
0f63d973be | ||
|
|
3b26129c82 | ||
|
|
d4bb4e6624 | ||
|
|
fa2abd560a | ||
|
|
0766c49f93 | ||
|
|
a7ffc77e3d | ||
|
|
e641fde25c | ||
|
|
564c2d763e | ||
|
|
5717c7f2f4 | ||
|
|
8734d4cb90 | ||
|
|
2f6004d74a | ||
|
|
08779cc8a8 | ||
|
|
5baa753539 | ||
|
|
92fb6b012a | ||
|
|
ead98e4bca | ||
|
|
a1634909e8 | ||
|
|
8f06f6a9ed | ||
|
|
ace7c0ccb4 | ||
|
|
f87fe0a0e8 | ||
|
|
87edc6f35e | ||
|
|
1d2fe55310 | ||
|
|
c175821cc4 | ||
|
|
239a28793c | ||
|
|
c421d653e7 | ||
|
|
2542c2920d | ||
|
|
52e46ced1b | ||
|
|
cf9daf470c | ||
|
|
ac7738bdeb | ||
|
|
2d9f6c104c | ||
|
|
5d0460ece2 | ||
|
|
140d6211cc | ||
|
|
60f9a1442c | ||
|
|
cb6caf3f87 | ||
|
|
c9301a6d18 | ||
|
|
0e77e93e5d | ||
|
|
99c7abbbf1 | ||
|
|
8f511ac33c | ||
|
|
1046152119 | ||
|
|
f88228f1c5 | ||
|
|
62e2b672d9 | ||
|
|
03005b5d29 | ||
|
|
c7e8830a56 | ||
|
|
d5ef4a6d15 | ||
|
|
97b67e0e49 | ||
|
|
dd6d78cb31 | ||
|
|
46433a25f8 | ||
|
|
b4e070697d | ||
|
|
c8843edb81 | ||
|
|
f89feb881c | ||
|
|
dbba71028e | ||
|
|
8549a92e9a | ||
|
|
109cffc010 | ||
|
|
f8f3ad84fc | ||
|
|
93d7883513 | ||
|
|
015a3e8a83 | ||
|
|
bc7167e9fe | ||
|
|
384578a88c | ||
|
|
6b074653f2 | ||
|
|
65b4e1ec6c | ||
|
|
06afa29f2d | ||
|
|
6600d58ba2 | ||
|
|
25e9be3ced | ||
|
|
ccb2aaf2fe | ||
|
|
961c6f67da | ||
|
|
dc4305f75a | ||
|
|
4dc7af5a5d | ||
|
|
902bea24b4 | ||
|
|
778cf4af9e | ||
|
|
c3ef46f409 | ||
|
|
4721c58d9c | ||
|
|
aa0b63e214 | ||
|
|
3c4e7997c3 | ||
|
|
1afc3a5f65 | ||
|
|
ea3d22831e | ||
|
|
3b4d6d359b | ||
|
|
48cba39a12 | ||
|
|
bca244df67 | ||
|
|
cec4e251bd | ||
|
|
526dd866ba | ||
|
|
c29839d2ed | ||
|
|
b31ddc7bf1 | ||
|
|
22e1ad3d8a | ||
|
|
f571b1deb0 | ||
|
|
67f8732683 | ||
|
|
2b387e169b | ||
|
|
199cf480b0 | ||
|
|
18daa023cb | ||
|
|
4ad6189487 | ||
|
|
8950d92682 | ||
|
|
fe5b3c80cb | ||
|
|
0ffcce3ec8 | ||
|
|
e0ffec885c | ||
|
|
ff4ff6bc2f | ||
|
|
f4fcfc5867 | ||
|
|
7248f65c36 | ||
|
|
5c40a2db21 | ||
|
|
d6111344c5 | ||
|
|
086eb3df7a | ||
|
|
ee2976cca0 | ||
|
|
8bc6df329f | ||
|
|
bcd4d9595f | ||
|
|
5a77b7728e | ||
|
|
1fbbba6f59 | ||
|
|
847be0e99d | ||
|
|
f6a2d072e6 | ||
|
|
ed8b0f25ee | ||
|
|
6e4a602c60 | ||
|
|
2262479365 | ||
|
|
33d66959e9 | ||
|
|
7f1b2b3f6e | ||
|
|
40ee065eff | ||
|
|
a75fb6af90 | ||
|
|
72f2125668 | ||
|
|
e8f5888d8e | ||
|
|
0b06d637e7 | ||
|
|
496f6770a5 | ||
|
|
5a7e5bd870 | ||
|
|
6f8a8f8136 | ||
|
|
5df195ea82 | ||
|
|
f82f70df5c | ||
|
|
5a2bf191fc | ||
|
|
a235fb1507 | ||
|
|
0d66522ed8 | ||
|
|
b163f8ed9e | ||
|
|
83e5f60b8b | ||
|
|
5b433f962f | ||
|
|
a1da6ff5ac | ||
|
|
5977af96a0 | ||
|
|
9b33fbf1cd | ||
|
|
43652d044c | ||
|
|
b1b379ea18 | ||
|
|
21ac161b21 | ||
|
|
94e979865e | ||
|
|
6c324f2c8b | ||
|
|
e0194d8511 | ||
|
|
216dafe44b | ||
|
|
d7dc9660af | ||
|
|
e0e30df323 | ||
|
|
543dfd67e0 | ||
|
|
bbd3eafde0 | ||
|
|
e9cd355893 | ||
|
|
c3e39267b8 | ||
|
|
b477aff611 | ||
|
|
28bd1323a2 | ||
|
|
220ca45f74 | ||
|
|
70a82d80ac | ||
|
|
ac626111ac | ||
|
|
8f6740fcef | ||
|
|
d829ac4cf7 | ||
|
|
f064f6e59d | ||
|
|
5bb9c2a2bd | ||
|
|
0b5bbe9234 | ||
|
|
14c74e5e84 | ||
|
|
6448d0ee7c | ||
|
|
b0c17af2cf | ||
|
|
8f27fd5c42 | ||
|
|
a9823ba58a | ||
|
|
8cfe26f10c | ||
|
|
80db2dc254 | ||
|
|
e8e3bc8616 | ||
|
|
ab5f5386e4 | ||
|
|
bc3195c8d8 | ||
|
|
6494330c6b | ||
|
|
89e34bf1e6 | ||
|
|
2574eec2ed | ||
|
|
514b9bf9fc | ||
|
|
4d7f389b69 | ||
|
|
95f87d5669 | ||
|
|
c83365a349 | ||
|
|
6b3604cf2b | ||
|
|
af6bdca14f | ||
|
|
58d45b4d58 | ||
|
|
1906ebcfce | ||
|
|
1c773c428f | ||
|
|
e785bfcd12 | ||
|
|
47dacce6ea | ||
|
|
dcac3407ab | ||
|
|
7004295e1d | ||
|
|
ee62ef4745 | ||
|
|
ef6bafbf7e | ||
|
|
ed28b71e87 | ||
|
|
d47b7dc79a | ||
|
|
49b9709ce5 | ||
|
|
a2eba2cdf5 | ||
|
|
3d01b3cfe8 | ||
|
|
af2efa6f7e | ||
|
|
d73b61d367 | ||
|
|
d3533f81fc | ||
|
|
59a448b645 | ||
|
|
3de7a7f0cd | ||
|
|
4adb9eed77 | ||
|
|
b6a0f7a07f | ||
|
|
b2566368f8 | ||
|
|
1b2f907671 | ||
|
|
bda04eed8a | ||
|
|
e0735977b5 | ||
|
|
67985d8226 | ||
|
|
cbcb061812 | ||
|
|
9fc2e1b3c8 | ||
|
|
3b484aea9e | ||
|
|
963a0950fa | ||
|
|
1fb4f2b12e | ||
|
|
f4ba1ab910 | ||
|
|
2662f91082 | ||
|
|
f5967069f2 | ||
|
|
80f5523685 | ||
|
|
c1db2c7d7c | ||
|
|
5e5d8142f9 | ||
|
|
b01619b441 | ||
|
|
109cf3928a | ||
|
|
4794645dec | ||
|
|
f861bd6a94 | ||
|
|
6dbfdd140d | ||
|
|
aa8526edc0 | ||
|
|
ac3ca0ad8e | ||
|
|
386ccffed4 | ||
|
|
08d21b76e2 | ||
|
|
ffddd1c90a | ||
|
|
33aa665555 | ||
|
|
00280b6fe8 | ||
|
|
08e8fddf73 | ||
|
|
8f8dfd081b | ||
|
|
9f1b445c7c | ||
|
|
ae933dfe14 | ||
|
|
5d33d6b8ea | ||
|
|
e124db723b | ||
|
|
05444cf32d | ||
|
|
478aff1189 | ||
|
|
8edbda57cf | ||
|
|
52760a4eaa | ||
|
|
821249a5ed | ||
|
|
2331b9a2e7 | ||
|
|
ee33863b47 | ||
|
|
cd22c849e2 | ||
|
|
f0e73efda2 | ||
|
|
3156109c71 | ||
|
|
6762e081f3 | ||
|
|
5ca3508284 | ||
|
|
771fec9447 | ||
|
|
7815ee338d | ||
|
|
44b6c872e2 | ||
|
|
7a77b23f2d | ||
|
|
672e8549c0 | ||
|
|
66f5269a23 | ||
|
|
4eaf769894 | ||
|
|
ebec293497 | ||
|
|
e02ceecd35 | ||
|
|
9116392a45 | ||
|
|
c8b33a8cc3 | ||
|
|
dca8d5ded8 | ||
|
|
2a7fd1e897 | ||
|
|
b9d1e70ac2 | ||
|
|
fdf5720217 | ||
|
|
f40bd0cd51 | ||
|
|
e33676bb87 | ||
|
|
b1f1cee1e5 | ||
|
|
a1ecc9ab00 | ||
|
|
2a663d5cba | ||
|
|
ba486ca6b7 | ||
|
|
750b930679 | ||
|
|
3902fd7501 | ||
|
|
4fc3d5e935 | ||
|
|
2d2f4572a7 | ||
|
|
8f4c46f38d | ||
|
|
b6ba51bc2a | ||
|
|
6a66d32d37 | ||
|
|
8d15723195 | ||
|
|
736e0aae86 | ||
|
|
8bf3305b2b | ||
|
|
d00e3ea973 | ||
|
|
89db4e9481 | ||
|
|
e332419081 | ||
|
|
19232c6388 | ||
|
|
e998b1229a | ||
|
|
bbed134bd1 | ||
|
|
47b9503112 | ||
|
|
3b9253c2be | ||
|
|
d241359153 | ||
|
|
f4d4249ba5 | ||
|
|
06075c0e93 | ||
|
|
e85c9d9322 | ||
|
|
cb56cb250e | ||
|
|
e0381a6ae0 | ||
|
|
2c01b2ef64 | ||
|
|
e947266743 | ||
|
|
c6b0e85b54 | ||
|
|
26efbed05c | ||
|
|
96340bf136 | ||
|
|
b055e00c1a | ||
|
|
414db44c00 | ||
|
|
857c880f99 | ||
|
|
ce7474d953 | ||
|
|
70fdd70b84 | ||
|
|
08ab6a7d77 | ||
|
|
5c95129884 | ||
|
|
9fa2a7e9df | ||
|
|
d443c86620 | ||
|
|
7be3f1c36c | ||
|
|
f6ab6d97b9 | ||
|
|
bc866bac49 | ||
|
|
50e6d845f4 | ||
|
|
a8cb01819d | ||
|
|
1a99cfded4 | ||
|
|
530273906b | ||
|
|
06ddf575d9 | ||
|
|
cf369d4684 | ||
|
|
3099114cbb | ||
|
|
44b63f0767 | ||
|
|
6705d20194 | ||
|
|
a38a9c0b0f | ||
|
|
8286caa366 | ||
|
|
bd1ec8424d | ||
|
|
225e2c6797 | ||
|
|
d8fc485513 | ||
|
|
f137eb0ac4 | ||
|
|
f39a460487 | ||
|
|
ee171bc563 | ||
|
|
a95428f204 | ||
|
|
cb3bdffb43 | ||
|
|
48f19aab51 | ||
|
|
48f6d7abdf | ||
|
|
79fbcb3ec4 | ||
|
|
0e4148b229 | ||
|
|
3ca5fb1046 | ||
|
|
a091d12f4e | ||
|
|
e3d8d726e6 | ||
|
|
457924828a | ||
|
|
aca2ef6359 | ||
|
|
ade7194792 | ||
|
|
0f51e73baa | ||
|
|
3a436e116a | ||
|
|
d06e2dc83c | ||
|
|
336867853b | ||
|
|
6403ff4ec4 | ||
|
|
d222469b44 | ||
|
|
790a17ce98 | ||
|
|
d473c952fb | ||
|
|
7646a2b877 | ||
|
|
62090f2568 | ||
|
|
d35152bbef | ||
|
|
c281f4cbaf | ||
|
|
09455f9e85 | ||
|
|
c8e72ba0dc | ||
|
|
375ef252ab | ||
|
|
ee552f8720 | ||
|
|
2e88c4858e | ||
|
|
3f50da85c1 | ||
|
|
8be06255f7 | ||
|
|
60936b5185 | ||
|
|
72274099aa | ||
|
|
b7f7b3a1d8 | ||
|
|
dcae098e23 | ||
|
|
618606966f | ||
|
|
05f249d77f | ||
|
|
2eb05ec640 | ||
|
|
3ce0d76aa4 | ||
|
|
a00b79d9be | ||
|
|
9fe6a215e6 | ||
|
|
33e53a2a56 | ||
|
|
cd5b80785f | ||
|
|
54f71aa273 | ||
|
|
3f949b7f84 | ||
|
|
cf8b2dcc85 | ||
|
|
8e24d9dc34 | ||
|
|
443c4538bb | ||
|
|
a7fc2ee4cf | ||
|
|
8e749ac22d | ||
|
|
69e09d9bc7 | ||
|
|
ed57d82bc1 | ||
|
|
06ad527e8c | ||
|
|
7af5a90a0b | ||
|
|
7551faff79 | ||
|
|
b7409dd2de | ||
|
|
5ba325a8fc | ||
|
|
d502840f91 | ||
|
|
99238a4b59 | ||
|
|
6d43a2ff9a | ||
|
|
cdb9c2e6e8 | ||
|
|
3faa1ca9af | ||
|
|
9d975e0375 | ||
|
|
2a6d8b78d4 | ||
|
|
671558a822 | ||
|
|
6b80ec79a0 | ||
|
|
d3f4783a24 | ||
|
|
1cb6bdbc87 | ||
|
|
96ddfc1f24 | ||
|
|
c169b32570 | ||
|
|
36a512fdf2 | ||
|
|
26fbb77901 | ||
|
|
a277302262 | ||
|
|
969c1a5b72 | ||
|
|
872339bceb | ||
|
|
5dc0dbc7aa | ||
|
|
ee6fc4e8a1 | ||
|
|
8fee16aecd | ||
|
|
2b7ba54a2f | ||
|
|
007c3304f2 | ||
|
|
e76ba0ede9 | ||
|
|
c06ac07e23 | ||
|
|
e592a57458 | ||
|
|
66769ec657 | ||
|
|
f413feec61 | ||
|
|
2e538e3486 | ||
|
|
9617a7b0d6 | ||
|
|
7569320770 | ||
|
|
8d25cf0d75 | ||
|
|
64e85e7019 | ||
|
|
a862984dca | ||
|
|
f0365f0465 | ||
|
|
6d1e20e940 | ||
|
|
0c0aae1eac | ||
|
|
5dcf7cb846 | ||
|
|
349b2ba3af | ||
|
|
98db5aabd0 | ||
|
|
e52b542e22 | ||
|
|
8f6abb8a86 | ||
|
|
ed8eaae964 | ||
|
|
7fd98f3556 | ||
|
|
e8de87ee90 | ||
|
|
4e572ec8b9 | ||
|
|
6c7f18c448 | ||
|
|
24bc9cba67 | ||
|
|
97356b1a04 | ||
|
|
1084b53fba | ||
|
|
b1aecc2bf1 | ||
|
|
83b90e106f | ||
|
|
f52114dab2 | ||
|
|
5106caf641 | ||
|
|
12370ee84e | ||
|
|
b84ccc6e7a | ||
|
|
e19ddb53e7 | ||
|
|
5bf89dd757 | ||
|
|
2a0100b2d6 | ||
|
|
4442574e53 | ||
|
|
c020fa60d0 | ||
|
|
b078be4613 | ||
|
|
71a6dffbb6 | ||
|
|
5f65dd5bb4 | ||
|
|
27b43ed63f | ||
|
|
f6a3a1d0ba | ||
|
|
830fd8eac2 | ||
|
|
a86d501dc2 | ||
|
|
24e8e20b59 | ||
|
|
e755e567ea | ||
|
|
a87f09bad2 | ||
|
|
dbcbe48ead | ||
|
|
63908869f6 | ||
|
|
db491c8f9b | ||
|
|
f6d625114c | ||
|
|
7dc40ba6d4 | ||
|
|
fcd6475377 | ||
|
|
4070c9de81 | ||
|
|
1e9e4a86a2 | ||
|
|
406a27271a | ||
|
|
9f9a4fc2af | ||
|
|
3fc410a253 | ||
|
|
781bc1521b | ||
|
|
05d201ece8 | ||
|
|
cd0c94f48a | ||
|
|
293cc8c1a3 | ||
|
|
453e744abf | ||
|
|
653439698e | ||
|
|
24970baa57 | ||
|
|
5418bbc338 | ||
|
|
89254cfc97 | ||
|
|
6bd9a034f7 | ||
|
|
26fc65b051 | ||
|
|
ed5ec5b55c | ||
|
|
df777650ac | ||
|
|
9855615f1e | ||
|
|
93414f1baa | ||
|
|
8fac6b147a | ||
|
|
10f8c795ac | ||
|
|
3e4858a624 | ||
|
|
1231dc9cda | ||
|
|
c84ff42bcd | ||
|
|
40d78908ed | ||
|
|
8a5db02165 | ||
|
|
56fa81f3c6 | ||
|
|
d7afb6eb0c | ||
|
|
03209b35c0 | ||
|
|
bbd1fe890a | ||
|
|
843316ea7a | ||
|
|
f607231efa | ||
|
|
2039062845 | ||
|
|
44f66d2257 | ||
|
|
99478d13a8 | ||
|
|
3b51a0fe12 | ||
|
|
2d91c2a3f5 | ||
|
|
bc6c4cdbfc | ||
|
|
69d3a80fc3 | ||
|
|
404546ce93 | ||
|
|
9e268ad103 | ||
|
|
6dd1cf1dd6 | ||
|
|
9058d406a3 | ||
|
|
9d9b9e7a0d | ||
|
|
13aa82f3f3 | ||
|
|
4ea5586b6f | ||
|
|
05e55d7dc5 | ||
|
|
1b358c931c | ||
|
|
e04b02113a | ||
|
|
3275494fde | ||
|
|
e3af8783b9 | ||
|
|
ca09db21ff | ||
|
|
c1f8211acb | ||
|
|
718ff7a73f | ||
|
|
fa70b220e9 | ||
|
|
98fa2a1597 | ||
|
|
0e7c79ba23 | ||
|
|
b6ba15fcbd | ||
|
|
e44167d7a4 | ||
|
|
1bfa75f780 | ||
|
|
bbcb5552f3 | ||
|
|
31bd90c748 | ||
|
|
1b8cb7b77b | ||
|
|
774f1fbc17 | ||
|
|
cfa8ddb59f | ||
|
|
39597267ae | ||
|
|
393e38f2c0 | ||
|
|
0f646800f6 | ||
|
|
ca993238f3 | ||
|
|
d1220de02d | ||
|
|
cf9a246d53 | ||
|
|
13eb5268de | ||
|
|
88798816f2 | ||
|
|
598f0af19b | ||
|
|
a33f5d31fc | ||
|
|
54acd69e9d | ||
|
|
d687ee2777 | ||
|
|
54c2fefbad | ||
|
|
506699fba1 | ||
|
|
f7b17ee6ec | ||
|
|
408614c74c | ||
|
|
68a27772b3 | ||
|
|
de87fb622b | ||
|
|
0155a01bb1 | ||
|
|
cfeee5d511 | ||
|
|
f27672f6cf | ||
|
|
28420c14e4 | ||
|
|
10e0ea1309 | ||
|
|
0bd221ff41 | ||
|
|
5fda6f8ef3 | ||
|
|
9b956f6338 | ||
|
|
09923f654c | ||
|
|
ae7b972649 | ||
|
|
47885e3710 | ||
|
|
4b9a260b37 | ||
|
|
462a70541e | ||
|
|
2407c1f4af | ||
|
|
2c743c8f0b | ||
|
|
9f2c278ee6 | ||
|
|
aea337cfe2 | ||
|
|
811f8f8b4f | ||
|
|
27734a23b1 | ||
|
|
1b8e538a77 | ||
|
|
41c2385aca | ||
|
|
d605985f45 | ||
|
|
d52b28b147 | ||
|
|
4afe1f42ca | ||
|
|
7481c0eaa0 | ||
|
|
024bc25b2c | ||
|
|
ffdfad8482 | ||
|
|
b91ee8d008 | ||
|
|
6586f08584 | ||
|
|
92c62bb2fb | ||
|
|
f49e887fe6 | ||
|
|
344066fd11 | ||
|
|
bcb8092488 | ||
|
|
1efade8bdb | ||
|
|
a5b3ff11fd | ||
|
|
084558f200 | ||
|
|
b602eae215 | ||
|
|
d02bf9c243 | ||
|
|
26a5f67df2 | ||
|
|
600fd42a83 | ||
|
|
670685139a | ||
|
|
52b6306388 | ||
|
|
f957b8948c | ||
|
|
cd0b14dd2d | ||
|
|
894703a484 | ||
|
|
521ec6f1b8 | ||
|
|
b0c5d9640a | ||
|
|
ef8e94e992 | ||
|
|
9df96a4bb4 | ||
|
|
28a428ae2f | ||
|
|
b326ec3641 | ||
|
|
fcecbc7d46 | ||
|
|
f4007f53ba | ||
|
|
d08a2453f7 | ||
|
|
3f53eea1e0 | ||
|
|
5a812a1e93 | ||
|
|
5e624cc7b1 | ||
|
|
f3d1cc8dc1 | ||
|
|
e889efeda7 | ||
|
|
0a3a95521c | ||
|
|
4ebaf6f7a9 | ||
|
|
59ac1a3f60 | ||
|
|
3af24597ee | ||
|
|
0b834fcb54 | ||
|
|
e0be6c5786 | ||
|
|
88b101ebf5 | ||
|
|
923a5d6efb | ||
|
|
734b7e42ad | ||
|
|
d9a65745df | ||
|
|
97ab623d42 | ||
|
|
14aa6cc7e8 | ||
|
|
10e77fcf24 | ||
|
|
bbb21d7c2b | ||
|
|
3bc489254b | ||
|
|
4c07ea41c3 | ||
|
|
f6720f8dfa | ||
|
|
e19ab3a066 | ||
|
|
c46099c5d7 | ||
|
|
8f1dd69e72 | ||
|
|
f26da24a2f | ||
|
|
407020de0c | ||
|
|
8e4fbcaa7d | ||
|
|
09c339953d | ||
|
|
367a05bdf6 | ||
|
|
d20b71deb9 | ||
|
|
712ce9f781 | ||
|
|
a4a3274a55 | ||
|
|
716aa71f6e | ||
|
|
e8976f9898 | ||
|
|
8496cc2444 | ||
|
|
5ef2d59e05 | ||
|
|
07bb89ae80 | ||
|
|
27a5ad8ec2 | ||
|
|
707b07c5f5 | ||
|
|
4a764afd76 | ||
|
|
ecf49d574b | ||
|
|
188de4ff2a | ||
|
|
5a75ef8ffd | ||
|
|
07279f8746 | ||
|
|
71f788b13a | ||
|
|
59c62dc580 |
@@ -13,8 +13,6 @@ Dockerfile
|
||||
docs/*
|
||||
README.md
|
||||
README_CN.md
|
||||
MANAGEMENT_API.md
|
||||
MANAGEMENT_API_CN.md
|
||||
LICENSE
|
||||
|
||||
# Runtime data folders (should be mounted as volumes)
|
||||
@@ -25,6 +23,14 @@ config.yaml
|
||||
|
||||
# Development/editor
|
||||
bin/*
|
||||
.claude/*
|
||||
.vscode/*
|
||||
.claude/*
|
||||
.codex/*
|
||||
.gemini/*
|
||||
.serena/*
|
||||
.agent/*
|
||||
.agents/*
|
||||
.opencode/*
|
||||
.bmad/*
|
||||
_bmad/*
|
||||
_bmad-output/*
|
||||
|
||||
7
.github/ISSUE_TEMPLATE/bug_report.md
vendored
7
.github/ISSUE_TEMPLATE/bug_report.md
vendored
@@ -7,6 +7,13 @@ assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Is it a request payload issue?**
|
||||
[ ] Yes, this is a request payload issue. I am using a client/cURL to send a request payload, but I received an unexpected error.
|
||||
[ ] No, it's another issue.
|
||||
|
||||
**If it's a request payload issue, you MUST know**
|
||||
Our team doesn't have any GODs or ORACLEs or MIND READERs. Please make sure to attach the request log or curl payload.
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
|
||||
1
.github/workflows/docker-image.yml
vendored
1
.github/workflows/docker-image.yml
vendored
@@ -44,3 +44,4 @@ jobs:
|
||||
tags: |
|
||||
${{ env.DOCKERHUB_REPO }}:latest
|
||||
${{ env.DOCKERHUB_REPO }}:${{ env.VERSION }}
|
||||
|
||||
|
||||
23
.github/workflows/pr-test-build.yml
vendored
Normal file
23
.github/workflows/pr-test-build.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: pr-test-build
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
cache: true
|
||||
- name: Build
|
||||
run: |
|
||||
go build -o test-output ./cmd/server
|
||||
rm -f test-output
|
||||
16
.gitignore
vendored
16
.gitignore
vendored
@@ -12,11 +12,15 @@ bin/*
|
||||
logs/*
|
||||
conv/*
|
||||
temp/*
|
||||
refs/*
|
||||
|
||||
# Storage backends
|
||||
pgstore/*
|
||||
gitstore/*
|
||||
objectstore/*
|
||||
|
||||
# Static assets
|
||||
static/*
|
||||
refs/*
|
||||
|
||||
# Authentication data
|
||||
auths/*
|
||||
@@ -30,10 +34,20 @@ GEMINI.md
|
||||
|
||||
# Tooling metadata
|
||||
.vscode/*
|
||||
.codex/*
|
||||
.claude/*
|
||||
.gemini/*
|
||||
.serena/*
|
||||
.agent/*
|
||||
.agents/*
|
||||
.agents/*
|
||||
.opencode/*
|
||||
.bmad/*
|
||||
_bmad/*
|
||||
_bmad-output/*
|
||||
.mcp/cache/
|
||||
|
||||
# macOS
|
||||
.DS_Store
|
||||
._*
|
||||
*.bak
|
||||
|
||||
76
README.md
76
README.md
@@ -13,6 +13,82 @@ The Plus release stays in lockstep with the mainline features.
|
||||
- Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)
|
||||
- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration), [Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)
|
||||
|
||||
## New Features (Plus Enhanced)
|
||||
|
||||
- **OAuth Web Authentication**: Browser-based OAuth login for Kiro with beautiful web UI
|
||||
- **Rate Limiter**: Built-in request rate limiting to prevent API abuse
|
||||
- **Background Token Refresh**: Automatic token refresh 10 minutes before expiration
|
||||
- **Metrics & Monitoring**: Request metrics collection for monitoring and debugging
|
||||
- **Device Fingerprint**: Device fingerprint generation for enhanced security
|
||||
- **Cooldown Management**: Smart cooldown mechanism for API rate limits
|
||||
- **Usage Checker**: Real-time usage monitoring and quota management
|
||||
- **Model Converter**: Unified model name conversion across providers
|
||||
- **UTF-8 Stream Processing**: Improved streaming response handling
|
||||
|
||||
## Kiro Authentication
|
||||
|
||||
### Web-based OAuth Login
|
||||
|
||||
Access the Kiro OAuth web interface at:
|
||||
|
||||
```
|
||||
http://your-server:8080/v0/oauth/kiro
|
||||
```
|
||||
|
||||
This provides a browser-based OAuth flow for Kiro (AWS CodeWhisperer) authentication with:
|
||||
- AWS Builder ID login
|
||||
- AWS Identity Center (IDC) login
|
||||
- Token import from Kiro IDE
|
||||
|
||||
## Quick Deployment with Docker
|
||||
|
||||
### One-Command Deployment
|
||||
|
||||
```bash
|
||||
# Create deployment directory
|
||||
mkdir -p ~/cli-proxy && cd ~/cli-proxy
|
||||
|
||||
# Create docker-compose.yml
|
||||
cat > docker-compose.yml << 'EOF'
|
||||
services:
|
||||
cli-proxy-api:
|
||||
image: 17600006524/cli-proxy-api-plus:latest
|
||||
container_name: cli-proxy-api-plus
|
||||
ports:
|
||||
- "8317:8317"
|
||||
volumes:
|
||||
- ./config.yaml:/CLIProxyAPI/config.yaml
|
||||
- ./auths:/root/.cli-proxy-api
|
||||
- ./logs:/CLIProxyAPI/logs
|
||||
restart: unless-stopped
|
||||
EOF
|
||||
|
||||
# Download example config
|
||||
curl -o config.yaml https://raw.githubusercontent.com/linlang781/CLIProxyAPIPlus/main/config.example.yaml
|
||||
|
||||
# Pull and start
|
||||
docker compose pull && docker compose up -d
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
Edit `config.yaml` before starting:
|
||||
|
||||
```yaml
|
||||
# Basic configuration example
|
||||
server:
|
||||
port: 8317
|
||||
|
||||
# Add your provider configurations here
|
||||
```
|
||||
|
||||
### Update to Latest Version
|
||||
|
||||
```bash
|
||||
cd ~/cli-proxy
|
||||
docker compose pull && docker compose up -d
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected.
|
||||
|
||||
76
README_CN.md
76
README_CN.md
@@ -13,6 +13,82 @@
|
||||
- 新增 GitHub Copilot 支持(OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供
|
||||
- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)、[Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)提供
|
||||
|
||||
## 新增功能 (Plus 增强版)
|
||||
|
||||
- **OAuth Web 认证**: 基于浏览器的 Kiro OAuth 登录,提供美观的 Web UI
|
||||
- **请求限流器**: 内置请求限流,防止 API 滥用
|
||||
- **后台令牌刷新**: 过期前 10 分钟自动刷新令牌
|
||||
- **监控指标**: 请求指标收集,用于监控和调试
|
||||
- **设备指纹**: 设备指纹生成,增强安全性
|
||||
- **冷却管理**: 智能冷却机制,应对 API 速率限制
|
||||
- **用量检查器**: 实时用量监控和配额管理
|
||||
- **模型转换器**: 跨供应商的统一模型名称转换
|
||||
- **UTF-8 流处理**: 改进的流式响应处理
|
||||
|
||||
## Kiro 认证
|
||||
|
||||
### 网页端 OAuth 登录
|
||||
|
||||
访问 Kiro OAuth 网页认证界面:
|
||||
|
||||
```
|
||||
http://your-server:8080/v0/oauth/kiro
|
||||
```
|
||||
|
||||
提供基于浏览器的 Kiro (AWS CodeWhisperer) OAuth 认证流程,支持:
|
||||
- AWS Builder ID 登录
|
||||
- AWS Identity Center (IDC) 登录
|
||||
- 从 Kiro IDE 导入令牌
|
||||
|
||||
## Docker 快速部署
|
||||
|
||||
### 一键部署
|
||||
|
||||
```bash
|
||||
# 创建部署目录
|
||||
mkdir -p ~/cli-proxy && cd ~/cli-proxy
|
||||
|
||||
# 创建 docker-compose.yml
|
||||
cat > docker-compose.yml << 'EOF'
|
||||
services:
|
||||
cli-proxy-api:
|
||||
image: 17600006524/cli-proxy-api-plus:latest
|
||||
container_name: cli-proxy-api-plus
|
||||
ports:
|
||||
- "8317:8317"
|
||||
volumes:
|
||||
- ./config.yaml:/CLIProxyAPI/config.yaml
|
||||
- ./auths:/root/.cli-proxy-api
|
||||
- ./logs:/CLIProxyAPI/logs
|
||||
restart: unless-stopped
|
||||
EOF
|
||||
|
||||
# 下载示例配置
|
||||
curl -o config.yaml https://raw.githubusercontent.com/linlang781/CLIProxyAPIPlus/main/config.example.yaml
|
||||
|
||||
# 拉取并启动
|
||||
docker compose pull && docker compose up -d
|
||||
```
|
||||
|
||||
### 配置说明
|
||||
|
||||
启动前请编辑 `config.yaml`:
|
||||
|
||||
```yaml
|
||||
# 基本配置示例
|
||||
server:
|
||||
port: 8317
|
||||
|
||||
# 在此添加你的供应商配置
|
||||
```
|
||||
|
||||
### 更新到最新版本
|
||||
|
||||
```bash
|
||||
cd ~/cli-proxy
|
||||
docker compose pull && docker compose up -d
|
||||
```
|
||||
|
||||
## 贡献
|
||||
|
||||
该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。
|
||||
|
||||
BIN
assets/cubence.png
Normal file
BIN
assets/cubence.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 51 KiB |
BIN
assets/packycode.png
Normal file
BIN
assets/packycode.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 8.1 KiB |
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cmd"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
@@ -74,10 +75,12 @@ func main() {
|
||||
var iflowLogin bool
|
||||
var iflowCookie bool
|
||||
var noBrowser bool
|
||||
var oauthCallbackPort int
|
||||
var antigravityLogin bool
|
||||
var kiroLogin bool
|
||||
var kiroGoogleLogin bool
|
||||
var kiroAWSLogin bool
|
||||
var kiroAWSAuthCode bool
|
||||
var kiroImport bool
|
||||
var githubCopilotLogin bool
|
||||
var projectID string
|
||||
@@ -95,12 +98,14 @@ func main() {
|
||||
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
||||
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
||||
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
|
||||
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
|
||||
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
|
||||
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
||||
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
|
||||
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
|
||||
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
|
||||
flag.BoolVar(&kiroAWSAuthCode, "kiro-aws-authcode", false, "Login to Kiro using AWS Builder ID (authorization code flow, better UX)")
|
||||
flag.BoolVar(&kiroImport, "kiro-import", false, "Import Kiro token from Kiro IDE (~/.aws/sso/cache/kiro-auth-token.json)")
|
||||
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
|
||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||
@@ -432,7 +437,7 @@ func main() {
|
||||
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
|
||||
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||
|
||||
if err = logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
|
||||
if err = logging.ConfigureLogOutput(cfg); err != nil {
|
||||
log.Errorf("failed to configure log output: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -452,7 +457,8 @@ func main() {
|
||||
|
||||
// Create login options to be used in authentication flows.
|
||||
options := &cmd.LoginOptions{
|
||||
NoBrowser: noBrowser,
|
||||
NoBrowser: noBrowser,
|
||||
CallbackPort: oauthCallbackPort,
|
||||
}
|
||||
|
||||
// Register the shared token store once so all components use the same persistence backend.
|
||||
@@ -513,6 +519,10 @@ func main() {
|
||||
// Users can explicitly override with --no-incognito
|
||||
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||
cmd.DoKiroAWSLogin(cfg, options)
|
||||
} else if kiroAWSAuthCode {
|
||||
// For Kiro auth with authorization code flow (better UX)
|
||||
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||
cmd.DoKiroAWSAuthCodeLogin(cfg, options)
|
||||
} else if kiroImport {
|
||||
cmd.DoKiroImport(cfg, options)
|
||||
} else {
|
||||
@@ -524,6 +534,13 @@ func main() {
|
||||
}
|
||||
// Start the main proxy service
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
|
||||
// 初始化并启动 Kiro token 后台刷新
|
||||
if cfg.AuthDir != "" {
|
||||
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
||||
defer kiro.StopGlobalRefreshManager()
|
||||
}
|
||||
|
||||
cmd.StartService(cfg, configFilePath, password)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,6 +25,9 @@ remote-management:
|
||||
# Disable the bundled management control panel asset download and HTTP route when true.
|
||||
disable-control-panel: false
|
||||
|
||||
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
|
||||
panel-github-repository: "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
|
||||
|
||||
# Authentication directory (supports ~ for home directory)
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
|
||||
@@ -32,10 +35,14 @@ auth-dir: "~/.cli-proxy-api"
|
||||
api-keys:
|
||||
- "your-api-key-1"
|
||||
- "your-api-key-2"
|
||||
- "your-api-key-3"
|
||||
|
||||
# Enable debug logging
|
||||
debug: false
|
||||
|
||||
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
|
||||
commercial-mode: false
|
||||
|
||||
# Open OAuth URLs in incognito/private browser mode.
|
||||
# Useful when you want to login with a different account without logging out from your current session.
|
||||
# Default: false (but Kiro auth defaults to true for multi-account support)
|
||||
@@ -44,12 +51,19 @@ incognito-browser: true
|
||||
# When true, write application logs to rotating files instead of stdout
|
||||
logging-to-file: false
|
||||
|
||||
# Maximum total size (MB) of log files under the logs directory. When exceeded, the oldest log
|
||||
# files are deleted until within the limit. Set to 0 to disable.
|
||||
logs-max-total-size-mb: 0
|
||||
|
||||
# When false, disable in-memory usage statistics aggregation
|
||||
usage-statistics-enabled: false
|
||||
|
||||
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
|
||||
proxy-url: ""
|
||||
|
||||
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
||||
force-model-prefix: false
|
||||
|
||||
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
||||
request-retry: 3
|
||||
|
||||
@@ -61,16 +75,36 @@ quota-exceeded:
|
||||
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
||||
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
|
||||
|
||||
# Routing strategy for selecting credentials when multiple match.
|
||||
routing:
|
||||
strategy: "round-robin" # round-robin (default), fill-first
|
||||
|
||||
# When true, enable authentication for the WebSocket API (/v1/ws).
|
||||
ws-auth: false
|
||||
|
||||
# When > 0, emit blank lines every N seconds for non-streaming responses to prevent idle timeouts.
|
||||
nonstream-keepalive-interval: 0
|
||||
|
||||
# Streaming behavior (SSE keep-alives + safe bootstrap retries).
|
||||
# streaming:
|
||||
# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.
|
||||
# bootstrap-retries: 1 # Default: 0 (disabled). Retries before first byte is sent.
|
||||
|
||||
# When true, enable official Codex instructions injection for Codex API requests.
|
||||
# When false (default), CodexInstructionsForModel returns immediately without modification.
|
||||
codex-instructions-enabled: false
|
||||
|
||||
# Gemini API keys
|
||||
# gemini-api-key:
|
||||
# - api-key: "AIzaSy...01"
|
||||
# prefix: "test" # optional: require calls like "test/gemini-3-pro-preview" to target this credential
|
||||
# base-url: "https://generativelanguage.googleapis.com"
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080"
|
||||
# models:
|
||||
# - name: "gemini-2.5-flash" # upstream model name
|
||||
# alias: "gemini-flash" # client alias mapped to the upstream model
|
||||
# excluded-models:
|
||||
# - "gemini-2.5-pro" # exclude specific models from this provider (exact match)
|
||||
# - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro)
|
||||
@@ -81,10 +115,14 @@ ws-auth: false
|
||||
# Codex API keys
|
||||
# codex-api-key:
|
||||
# - api-key: "sk-atSM..."
|
||||
# prefix: "test" # optional: require calls like "test/gpt-5-codex" to target this credential
|
||||
# base-url: "https://www.example.com" # use the custom codex API endpoint
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||
# models:
|
||||
# - name: "gpt-5-codex" # upstream model name
|
||||
# alias: "codex-latest" # client alias mapped to the upstream model
|
||||
# excluded-models:
|
||||
# - "gpt-5.1" # exclude specific models (exact match)
|
||||
# - "gpt-5-*" # wildcard matching prefix (e.g. gpt-5-medium, gpt-5-codex)
|
||||
@@ -95,18 +133,28 @@ ws-auth: false
|
||||
# claude-api-key:
|
||||
# - api-key: "sk-atSM..." # use the official claude API key, no need to set the base url
|
||||
# - api-key: "sk-atSM..."
|
||||
# prefix: "test" # optional: require calls like "test/claude-sonnet-latest" to target this credential
|
||||
# base-url: "https://www.example.com" # use the custom claude API endpoint
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||
# models:
|
||||
# - name: "claude-3-5-sonnet-20241022" # upstream model name
|
||||
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
||||
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
||||
# excluded-models:
|
||||
# - "claude-opus-4-5-20251101" # exclude specific models (exact match)
|
||||
# - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219)
|
||||
# - "*-thinking" # wildcard matching suffix (e.g. claude-opus-4-5-thinking)
|
||||
# - "*haiku*" # wildcard matching substring (e.g. claude-3-5-haiku-20241022)
|
||||
# cloak: # optional: request cloaking for non-Claude-Code clients
|
||||
# mode: "auto" # "auto" (default): cloak only when client is not Claude Code
|
||||
# # "always": always apply cloaking
|
||||
# # "never": never apply cloaking
|
||||
# strict-mode: false # false (default): prepend Claude Code prompt to user system messages
|
||||
# # true: strip all user system messages, keep only Claude Code prompt
|
||||
# sensitive-words: # optional: words to obfuscate with zero-width characters
|
||||
# - "API"
|
||||
# - "proxy"
|
||||
|
||||
# Kiro (AWS CodeWhisperer) configuration
|
||||
# Note: Kiro API currently only operates in us-east-1 region
|
||||
@@ -121,6 +169,7 @@ ws-auth: false
|
||||
# OpenAI compatibility providers
|
||||
# openai-compatibility:
|
||||
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
|
||||
# prefix: "test" # optional: require calls like "test/kimi-k2" to target this provider's credentials
|
||||
# base-url: "https://openrouter.ai/api/v1" # The base URL of the provider.
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
@@ -135,14 +184,15 @@ ws-auth: false
|
||||
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
|
||||
# vertex-api-key:
|
||||
# - api-key: "vk-123..." # x-goog-api-key header
|
||||
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
|
||||
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# models: # optional: map aliases to upstream model names
|
||||
# - name: "gemini-2.0-flash" # upstream model name
|
||||
# - name: "gemini-2.5-flash" # upstream model name
|
||||
# alias: "vertex-flash" # client-visible alias
|
||||
# - name: "gemini-1.5-pro"
|
||||
# - name: "gemini-2.5-pro"
|
||||
# alias: "vertex-pro"
|
||||
|
||||
# Amp Integration
|
||||
@@ -151,8 +201,20 @@ ws-auth: false
|
||||
# upstream-url: "https://ampcode.com"
|
||||
# # Optional: Override API key for Amp upstream (otherwise uses env or file)
|
||||
# upstream-api-key: ""
|
||||
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (recommended)
|
||||
# restrict-management-to-localhost: true
|
||||
# # Per-client upstream API key mapping
|
||||
# # Maps client API keys (from top-level api-keys) to different Amp upstream API keys.
|
||||
# # Useful when different clients need to use different Amp accounts/quotas.
|
||||
# # If a client key isn't mapped, falls back to upstream-api-key (default behavior).
|
||||
# upstream-api-keys:
|
||||
# - upstream-api-key: "amp_key_for_team_a" # Upstream key to use for these clients
|
||||
# api-keys: # Client keys that use this upstream key
|
||||
# - "your-api-key-1"
|
||||
# - "your-api-key-2"
|
||||
# - upstream-api-key: "amp_key_for_team_b"
|
||||
# api-keys:
|
||||
# - "your-api-key-3"
|
||||
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false)
|
||||
# restrict-management-to-localhost: false
|
||||
# # Force model mappings to run before checking local API keys (default: false)
|
||||
# force-model-mappings: false
|
||||
# # Amp Model Mappings
|
||||
@@ -160,14 +222,68 @@ ws-auth: false
|
||||
# # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)
|
||||
# # but you have a similar model available (e.g., Claude Sonnet 4).
|
||||
# model-mappings:
|
||||
# - from: "claude-opus-4.5" # Model requested by Amp CLI
|
||||
# to: "claude-sonnet-4" # Route to this available model instead
|
||||
# - from: "gpt-5"
|
||||
# to: "gemini-2.5-pro"
|
||||
# - from: "claude-3-opus-20240229"
|
||||
# to: "claude-3-5-sonnet-20241022"
|
||||
# - from: "claude-opus-4-5-20251101" # Model requested by Amp CLI
|
||||
# to: "gemini-claude-opus-4-5-thinking" # Route to this available model instead
|
||||
# - from: "claude-sonnet-4-5-20250929"
|
||||
# to: "gemini-claude-sonnet-4-5-thinking"
|
||||
# - from: "claude-haiku-4-5-20251001"
|
||||
# to: "gemini-2.5-flash"
|
||||
|
||||
# Global OAuth model name aliases (per channel)
|
||||
# These aliases rename model IDs for both model listing and request routing.
|
||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
|
||||
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
||||
# You can repeat the same name with different aliases to expose multiple client model names.
|
||||
#oauth-model-alias:
|
||||
# antigravity:
|
||||
# - name: "rev19-uic3-1p"
|
||||
# alias: "gemini-2.5-computer-use-preview-10-2025"
|
||||
# - name: "gemini-3-pro-image"
|
||||
# alias: "gemini-3-pro-image-preview"
|
||||
# - name: "gemini-3-pro-high"
|
||||
# alias: "gemini-3-pro-preview"
|
||||
# - name: "gemini-3-flash"
|
||||
# alias: "gemini-3-flash-preview"
|
||||
# - name: "claude-sonnet-4-5"
|
||||
# alias: "gemini-claude-sonnet-4-5"
|
||||
# - name: "claude-sonnet-4-5-thinking"
|
||||
# alias: "gemini-claude-sonnet-4-5-thinking"
|
||||
# - name: "claude-opus-4-5-thinking"
|
||||
# alias: "gemini-claude-opus-4-5-thinking"
|
||||
# gemini-cli:
|
||||
# - name: "gemini-2.5-pro" # original model name under this channel
|
||||
# alias: "g2.5p" # client-visible alias
|
||||
# fork: true # when true, keep original and also add the alias as an extra model (default: false)
|
||||
# vertex:
|
||||
# - name: "gemini-2.5-pro"
|
||||
# alias: "g2.5p"
|
||||
# aistudio:
|
||||
# - name: "gemini-2.5-pro"
|
||||
# alias: "g2.5p"
|
||||
# antigravity:
|
||||
# - name: "gemini-3-pro-preview"
|
||||
# alias: "g3p"
|
||||
# claude:
|
||||
# - name: "claude-sonnet-4-5-20250929"
|
||||
# alias: "cs4.5"
|
||||
# codex:
|
||||
# - name: "gpt-5"
|
||||
# alias: "g5"
|
||||
# qwen:
|
||||
# - name: "qwen3-coder-plus"
|
||||
# alias: "qwen-plus"
|
||||
# iflow:
|
||||
# - name: "glm-4.7"
|
||||
# alias: "glm-god"
|
||||
# kiro:
|
||||
# - name: "kiro-claude-opus-4-5"
|
||||
# alias: "op45"
|
||||
# github-copilot:
|
||||
# - name: "gpt-5"
|
||||
# alias: "copilot-gpt5"
|
||||
|
||||
# OAuth provider excluded models
|
||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
|
||||
# oauth-excluded-models:
|
||||
# gemini-cli:
|
||||
# - "gemini-2.5-pro" # exclude specific models (exact match)
|
||||
@@ -188,6 +304,10 @@ ws-auth: false
|
||||
# - "vision-model"
|
||||
# iflow:
|
||||
# - "tstars2.0"
|
||||
# kiro:
|
||||
# - "kiro-claude-haiku-4-5"
|
||||
# github-copilot:
|
||||
# - "raptor-mini"
|
||||
|
||||
# Optional payload configuration
|
||||
# payload:
|
||||
@@ -197,9 +317,21 @@ ws-auth: false
|
||||
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
|
||||
# params: # JSON path (gjson/sjson syntax) -> value
|
||||
# "generationConfig.thinkingConfig.thinkingBudget": 32768
|
||||
# default-raw: # Default raw rules set parameters using raw JSON when missing (must be valid JSON).
|
||||
# - models:
|
||||
# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
|
||||
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
|
||||
# params: # JSON path (gjson/sjson syntax) -> raw JSON value (strings are used as-is, must be valid JSON)
|
||||
# "generationConfig.responseJsonSchema": "{\"type\":\"object\",\"properties\":{\"answer\":{\"type\":\"string\"}}}"
|
||||
# override: # Override rules always set parameters, overwriting any existing values.
|
||||
# - models:
|
||||
# - name: "gpt-*" # Supports wildcards (e.g., "gpt-*")
|
||||
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
|
||||
# params: # JSON path (gjson/sjson syntax) -> value
|
||||
# "reasoning.effort": "high"
|
||||
# override-raw: # Override raw rules always set parameters using raw JSON (must be valid JSON).
|
||||
# - models:
|
||||
# - name: "gpt-*" # Supports wildcards (e.g., "gpt-*")
|
||||
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
|
||||
# params: # JSON path (gjson/sjson syntax) -> raw JSON value (strings are used as-is, must be valid JSON)
|
||||
# "response_format": "{\"type\":\"json_schema\",\"json_schema\":{\"name\":\"answer\",\"schema\":{\"type\":\"object\"}}}"
|
||||
|
||||
128
docker-build.sh
128
docker-build.sh
@@ -5,9 +5,115 @@
|
||||
# This script automates the process of building and running the Docker container
|
||||
# with version information dynamically injected at build time.
|
||||
|
||||
# Exit immediately if a command exits with a non-zero status.
|
||||
# Hidden feature: Preserve usage statistics across rebuilds
|
||||
# Usage: ./docker-build.sh --with-usage
|
||||
# First run prompts for management API key, saved to temp/stats/.api_secret
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
STATS_DIR="temp/stats"
|
||||
STATS_FILE="${STATS_DIR}/.usage_backup.json"
|
||||
SECRET_FILE="${STATS_DIR}/.api_secret"
|
||||
WITH_USAGE=false
|
||||
|
||||
get_port() {
|
||||
if [[ -f "config.yaml" ]]; then
|
||||
grep -E "^port:" config.yaml | sed -E 's/^port: *["'"'"']?([0-9]+)["'"'"']?.*$/\1/'
|
||||
else
|
||||
echo "8317"
|
||||
fi
|
||||
}
|
||||
|
||||
export_stats_api_secret() {
|
||||
if [[ -f "${SECRET_FILE}" ]]; then
|
||||
API_SECRET=$(cat "${SECRET_FILE}")
|
||||
else
|
||||
if [[ ! -d "${STATS_DIR}" ]]; then
|
||||
mkdir -p "${STATS_DIR}"
|
||||
fi
|
||||
echo "First time using --with-usage. Management API key required."
|
||||
read -r -p "Enter management key: " -s API_SECRET
|
||||
echo
|
||||
echo "${API_SECRET}" > "${SECRET_FILE}"
|
||||
chmod 600 "${SECRET_FILE}"
|
||||
fi
|
||||
}
|
||||
|
||||
check_container_running() {
|
||||
local port
|
||||
port=$(get_port)
|
||||
|
||||
if ! curl -s -o /dev/null -w "%{http_code}" "http://localhost:${port}/" | grep -q "200"; then
|
||||
echo "Error: cli-proxy-api service is not responding at localhost:${port}"
|
||||
echo "Please start the container first or use without --with-usage flag."
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
export_stats() {
|
||||
local port
|
||||
port=$(get_port)
|
||||
|
||||
if [[ ! -d "${STATS_DIR}" ]]; then
|
||||
mkdir -p "${STATS_DIR}"
|
||||
fi
|
||||
check_container_running
|
||||
echo "Exporting usage statistics..."
|
||||
EXPORT_RESPONSE=$(curl -s -w "\n%{http_code}" -H "X-Management-Key: ${API_SECRET}" \
|
||||
"http://localhost:${port}/v0/management/usage/export")
|
||||
HTTP_CODE=$(echo "${EXPORT_RESPONSE}" | tail -n1)
|
||||
RESPONSE_BODY=$(echo "${EXPORT_RESPONSE}" | sed '$d')
|
||||
|
||||
if [[ "${HTTP_CODE}" != "200" ]]; then
|
||||
echo "Export failed (HTTP ${HTTP_CODE}): ${RESPONSE_BODY}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "${RESPONSE_BODY}" > "${STATS_FILE}"
|
||||
echo "Statistics exported to ${STATS_FILE}"
|
||||
}
|
||||
|
||||
import_stats() {
|
||||
local port
|
||||
port=$(get_port)
|
||||
|
||||
echo "Importing usage statistics..."
|
||||
IMPORT_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST \
|
||||
-H "X-Management-Key: ${API_SECRET}" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d @"${STATS_FILE}" \
|
||||
"http://localhost:${port}/v0/management/usage/import")
|
||||
IMPORT_CODE=$(echo "${IMPORT_RESPONSE}" | tail -n1)
|
||||
IMPORT_BODY=$(echo "${IMPORT_RESPONSE}" | sed '$d')
|
||||
|
||||
if [[ "${IMPORT_CODE}" == "200" ]]; then
|
||||
echo "Statistics imported successfully"
|
||||
else
|
||||
echo "Import failed (HTTP ${IMPORT_CODE}): ${IMPORT_BODY}"
|
||||
fi
|
||||
|
||||
rm -f "${STATS_FILE}"
|
||||
}
|
||||
|
||||
wait_for_service() {
|
||||
local port
|
||||
port=$(get_port)
|
||||
|
||||
echo "Waiting for service to be ready..."
|
||||
for i in {1..30}; do
|
||||
if curl -s -o /dev/null -w "%{http_code}" "http://localhost:${port}/" | grep -q "200"; then
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
sleep 2
|
||||
}
|
||||
|
||||
if [[ "${1:-}" == "--with-usage" ]]; then
|
||||
WITH_USAGE=true
|
||||
export_stats_api_secret
|
||||
fi
|
||||
|
||||
# --- Step 1: Choose Environment ---
|
||||
echo "Please select an option:"
|
||||
echo "1) Run using Pre-built Image (Recommended)"
|
||||
@@ -18,7 +124,14 @@ read -r -p "Enter choice [1-2]: " choice
|
||||
case "$choice" in
|
||||
1)
|
||||
echo "--- Running with Pre-built Image ---"
|
||||
if [[ "${WITH_USAGE}" == "true" ]]; then
|
||||
export_stats
|
||||
fi
|
||||
docker compose up -d --remove-orphans --no-build
|
||||
if [[ "${WITH_USAGE}" == "true" ]]; then
|
||||
wait_for_service
|
||||
import_stats
|
||||
fi
|
||||
echo "Services are starting from remote image."
|
||||
echo "Run 'docker compose logs -f' to see the logs."
|
||||
;;
|
||||
@@ -38,16 +151,25 @@ case "$choice" in
|
||||
|
||||
# Build and start the services with a local-only image tag
|
||||
export CLI_PROXY_IMAGE="cli-proxy-api:local"
|
||||
|
||||
|
||||
echo "Building the Docker image..."
|
||||
docker compose build \
|
||||
--build-arg VERSION="${VERSION}" \
|
||||
--build-arg COMMIT="${COMMIT}" \
|
||||
--build-arg BUILD_DATE="${BUILD_DATE}"
|
||||
|
||||
if [[ "${WITH_USAGE}" == "true" ]]; then
|
||||
export_stats
|
||||
fi
|
||||
|
||||
echo "Starting the services..."
|
||||
docker compose up -d --remove-orphans --pull never
|
||||
|
||||
if [[ "${WITH_USAGE}" == "true" ]]; then
|
||||
wait_for_service
|
||||
import_stats
|
||||
fi
|
||||
|
||||
echo "Build complete. Services are starting."
|
||||
echo "Run 'docker compose logs -f' to see the logs."
|
||||
;;
|
||||
@@ -55,4 +177,4 @@ case "$choice" in
|
||||
echo "Invalid choice. Please enter 1 or 2."
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
esac
|
||||
|
||||
@@ -22,7 +22,7 @@ services:
|
||||
- "51121:51121"
|
||||
- "11451:11451"
|
||||
volumes:
|
||||
- ./config.yaml:/CLIProxyAPI/config.yaml
|
||||
- ./auths:/root/.cli-proxy-api
|
||||
- ./logs:/CLIProxyAPI/logs
|
||||
- ${CLI_PROXY_CONFIG_PATH:-./config.yaml}:/CLIProxyAPI/config.yaml
|
||||
- ${CLI_PROXY_AUTH_PATH:-./auths}:/root/.cli-proxy-api
|
||||
- ${CLI_PROXY_LOG_PATH:-./logs}:/CLIProxyAPI/logs
|
||||
restart: unless-stopped
|
||||
|
||||
@@ -1,443 +0,0 @@
|
||||
# Amp CLI Integration Guide
|
||||
|
||||
This guide explains how to use CLIProxyAPI with Amp CLI and Amp IDE extensions, enabling you to use your existing Google/ChatGPT/Claude subscriptions (via OAuth) with Amp's CLI.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Overview](#overview)
|
||||
- [Which Providers Should You Authenticate?](#which-providers-should-you-authenticate)
|
||||
- [Architecture](#architecture)
|
||||
- [Configuration](#configuration)
|
||||
- [Model Mapping Configuration](#model-mapping-configuration)
|
||||
- [Setup](#setup)
|
||||
- [Usage](#usage)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
|
||||
## Overview
|
||||
|
||||
The Amp CLI integration adds specialized routing to support Amp's API patterns while maintaining full compatibility with all existing CLIProxyAPI features. This allows you to use both traditional CLIProxyAPI features and Amp CLI with the same proxy server.
|
||||
|
||||
### Key Features
|
||||
|
||||
- **Provider route aliases**: Maps Amp's `/api/provider/{provider}/v1...` patterns to CLIProxyAPI handlers
|
||||
- **Management proxy**: Forwards OAuth and account management requests to Amp's control plane
|
||||
- **Smart fallback**: Automatically routes unconfigured models to ampcode.com
|
||||
- **Model mapping**: Route unavailable models to alternatives you have access to (e.g., `claude-opus-4.5` → `claude-sonnet-4`)
|
||||
- **Secret management**: Configurable precedence (config > env > file) with 5-minute caching
|
||||
- **Security-first**: Management routes restricted to localhost by default
|
||||
- **Automatic gzip handling**: Decompresses responses from Amp upstream
|
||||
|
||||
### What You Can Do
|
||||
|
||||
- Use Amp CLI with your Google account (Gemini 3 Pro Preview, Gemini 2.5 Pro, Gemini 2.5 Flash)
|
||||
- Use Amp CLI with your ChatGPT Plus/Pro subscription (GPT-5, GPT-5 Codex models)
|
||||
- Use Amp CLI with your Claude Pro/Max subscription (Claude Sonnet 4.5, Opus 4.1)
|
||||
- Use Amp IDE extensions (VS Code, Cursor, Windsurf, etc.) with the same proxy
|
||||
- Run multiple CLI tools (Factory + Amp) through one proxy server
|
||||
- Route unconfigured models automatically through ampcode.com
|
||||
|
||||
### Which Providers Should You Authenticate?
|
||||
|
||||
**Important**: The providers you need to authenticate depend on which models and features your installed version of Amp currently uses. Amp employs different providers for various agent modes and specialized subagents:
|
||||
|
||||
- **Smart mode**: Uses Google/Gemini models (Gemini 3 Pro)
|
||||
- **Rush mode**: Uses Anthropic/Claude models (Claude Haiku 4.5)
|
||||
- **Oracle subagent**: Uses OpenAI/GPT models (GPT-5 medium reasoning)
|
||||
- **Librarian subagent**: Uses Anthropic/Claude models (Claude Sonnet 4.5)
|
||||
- **Search subagent**: Uses Anthropic/Claude models (Claude Haiku 4.5)
|
||||
- **Review feature**: Uses Google/Gemini models (Gemini 2.5 Flash-Lite)
|
||||
|
||||
For the most current information about which models Amp uses, see the **[Amp Models Documentation](https://ampcode.com/models)**.
|
||||
|
||||
#### Fallback Behavior
|
||||
|
||||
CLIProxyAPI uses a smart fallback system:
|
||||
|
||||
1. **Provider authenticated locally** (`--login`, `--codex-login`, `--claude-login`):
|
||||
- Requests use **your OAuth subscription** (ChatGPT Plus/Pro, Claude Pro/Max, Google account)
|
||||
- You benefit from your subscription's included usage quotas
|
||||
- No Amp credits consumed
|
||||
|
||||
2. **Provider NOT authenticated locally**:
|
||||
- Requests automatically forward to **ampcode.com**
|
||||
- Uses Amp's backend provider connections
|
||||
- **Requires Amp credits** if the provider is paid (OpenAI, Anthropic paid tiers)
|
||||
- May result in errors if Amp credit balance is insufficient
|
||||
|
||||
**Recommendation**: Authenticate all providers you have subscriptions for to maximize value and minimize Amp credit usage. If you don't have subscriptions to all providers Amp uses, ensure you have sufficient Amp credits available for fallback requests.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Request Flow
|
||||
|
||||
```
|
||||
Amp CLI/IDE
|
||||
↓
|
||||
├─ Provider API requests (/api/provider/{provider}/v1/...)
|
||||
│ ↓
|
||||
│ ├─ Model configured locally?
|
||||
│ │ YES → Use local OAuth tokens (OpenAI/Claude/Gemini handlers)
|
||||
│ │ NO ↓
|
||||
│ │ ├─ Model mapping configured?
|
||||
│ │ │ YES → Rewrite model → Use local handler (free)
|
||||
│ │ │ NO → Forward to ampcode.com (uses Amp credits)
|
||||
│ ↓
|
||||
│ Response
|
||||
│
|
||||
└─ Management requests (/api/auth, /api/user, /api/threads, ...)
|
||||
↓
|
||||
├─ Localhost check (security)
|
||||
↓
|
||||
└─ Reverse proxy to ampcode.com
|
||||
↓
|
||||
Response (auto-decompressed if gzipped)
|
||||
```
|
||||
|
||||
### Components
|
||||
|
||||
The Amp integration is implemented as a modular routing module (`internal/api/modules/amp/`) with these components:
|
||||
|
||||
1. **Route Aliases** (`routes.go`): Maps Amp-style paths to standard handlers
|
||||
2. **Reverse Proxy** (`proxy.go`): Forwards management requests to ampcode.com
|
||||
3. **Fallback Handler** (`fallback_handlers.go`): Routes unconfigured models to ampcode.com
|
||||
4. **Secret Management** (`secret.go`): Multi-source API key resolution with caching
|
||||
5. **Main Module** (`amp.go`): Orchestrates registration and configuration
|
||||
|
||||
## Configuration
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
Add these fields to your `config.yaml`:
|
||||
|
||||
```yaml
|
||||
# Amp upstream control plane (required for management routes)
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
|
||||
# Optional: Override API key (otherwise uses env or file)
|
||||
# amp-upstream-api-key: "your-amp-api-key"
|
||||
|
||||
# Security: restrict management routes to localhost (recommended)
|
||||
amp-restrict-management-to-localhost: true
|
||||
```
|
||||
|
||||
### Model Mapping Configuration
|
||||
|
||||
When Amp CLI requests a model that you don't have access to, you can configure mappings to route those requests to alternative models that you DO have available. This avoids consuming Amp credits for models you could handle locally.
|
||||
|
||||
```yaml
|
||||
# Route unavailable models to alternatives
|
||||
amp-model-mappings:
|
||||
# Example: Route Claude Opus 4.5 requests to Claude Sonnet 4
|
||||
- from: "claude-opus-4.5"
|
||||
to: "claude-sonnet-4"
|
||||
|
||||
# Example: Route GPT-5 requests to Gemini 2.5 Pro
|
||||
- from: "gpt-5"
|
||||
to: "gemini-2.5-pro"
|
||||
|
||||
# Example: Map older model names to newer versions
|
||||
- from: "claude-3-opus-20240229"
|
||||
to: "claude-3-5-sonnet-20241022"
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
|
||||
1. Amp CLI requests a model (e.g., `claude-opus-4.5`)
|
||||
2. CLIProxyAPI checks if a local provider is available for that model
|
||||
3. If not available, it checks the model mappings
|
||||
4. If a mapping exists, the request is rewritten to use the target model
|
||||
5. The request is then handled locally (free, using your OAuth subscription)
|
||||
|
||||
**Benefits:**
|
||||
- **Save Amp credits**: Use your local subscriptions instead of forwarding to ampcode.com
|
||||
- **Hot-reload**: Mappings can be updated without restarting the proxy
|
||||
- **Structured logging**: Clear logs show when mappings are applied
|
||||
|
||||
**Routing Decision Logs:**
|
||||
|
||||
The proxy logs each routing decision with structured fields:
|
||||
|
||||
```
|
||||
[AMP] Using local provider for model: gemini-2.5-pro # Local provider (free)
|
||||
[AMP] Model mapped: claude-opus-4.5 -> claude-sonnet-4 # Mapping applied (free)
|
||||
[AMP] Forwarding to ampcode.com (uses Amp credits) - model_id: gpt-5 # Fallback (costs credits)
|
||||
```
|
||||
|
||||
### Secret Resolution Precedence
|
||||
|
||||
The Amp module resolves API keys using this precedence order:
|
||||
|
||||
| Source | Key | Priority | Cache |
|
||||
|--------|-----|----------|-------|
|
||||
| Config file | `amp-upstream-api-key` | High | No |
|
||||
| Environment | `AMP_API_KEY` | Medium | No |
|
||||
| Amp secrets file | `~/.local/share/amp/secrets.json` | Low | 5 min |
|
||||
|
||||
**Recommendation**: Use the Amp secrets file (lowest precedence) for normal usage. This file is automatically managed by `amp login`.
|
||||
|
||||
### Security Settings
|
||||
|
||||
**`amp-restrict-management-to-localhost`** (default: `true`)
|
||||
|
||||
When enabled, management routes (`/api/auth`, `/api/user`, `/api/threads`, etc.) only accept connections from localhost (127.0.0.1, ::1). This prevents:
|
||||
- Drive-by browser attacks
|
||||
- Remote access to management endpoints
|
||||
- CORS-based attacks
|
||||
- Header spoofing attacks (e.g., `X-Forwarded-For: 127.0.0.1`)
|
||||
|
||||
#### How It Works
|
||||
|
||||
This restriction uses the **actual TCP connection address** (`RemoteAddr`), not HTTP headers like `X-Forwarded-For`. This prevents header spoofing attacks but has important implications:
|
||||
|
||||
- ✅ **Works for direct connections**: Running CLIProxyAPI directly on your machine or server
|
||||
- ⚠️ **May not work behind reverse proxies**: If deploying behind nginx, Cloudflare, or other proxies, the connection will appear to come from the proxy's IP, not localhost
|
||||
|
||||
#### Reverse Proxy Deployments
|
||||
|
||||
If you need to run CLIProxyAPI behind a reverse proxy (nginx, Caddy, Cloudflare Tunnel, etc.):
|
||||
|
||||
1. **Disable the localhost restriction**:
|
||||
```yaml
|
||||
amp-restrict-management-to-localhost: false
|
||||
```
|
||||
|
||||
2. **Use alternative security measures**:
|
||||
- Firewall rules restricting access to management routes
|
||||
- Proxy-level authentication (HTTP Basic Auth, OAuth)
|
||||
- Network-level isolation (VPN, Tailscale, Cloudflare Access)
|
||||
- Bind CLIProxyAPI to `127.0.0.1` only and access via SSH tunnel
|
||||
|
||||
3. **Example nginx configuration** (blocks external access to management routes):
|
||||
```nginx
|
||||
location /api/auth { deny all; }
|
||||
location /api/user { deny all; }
|
||||
location /api/threads { deny all; }
|
||||
location /api/internal { deny all; }
|
||||
```
|
||||
|
||||
**Important**: Only disable `amp-restrict-management-to-localhost` if you understand the security implications and have other protections in place.
|
||||
|
||||
## Setup
|
||||
|
||||
### 1. Configure CLIProxyAPI
|
||||
|
||||
Create or edit `config.yaml`:
|
||||
|
||||
```yaml
|
||||
port: 8317
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
|
||||
# Amp integration
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
amp-restrict-management-to-localhost: true
|
||||
|
||||
# Other standard settings...
|
||||
debug: false
|
||||
logging-to-file: true
|
||||
```
|
||||
|
||||
### 2. Authenticate with Providers
|
||||
|
||||
Run OAuth login for the providers you want to use:
|
||||
|
||||
**Google Account (Gemini 2.5 Pro, Gemini 2.5 Flash, Gemini 3 Pro Preview):**
|
||||
```bash
|
||||
./cli-proxy-api --login
|
||||
```
|
||||
|
||||
**ChatGPT Plus/Pro (GPT-5, GPT-5 Codex):**
|
||||
```bash
|
||||
./cli-proxy-api --codex-login
|
||||
```
|
||||
|
||||
**Claude Pro/Max (Claude Sonnet 4.5, Opus 4.1):**
|
||||
```bash
|
||||
./cli-proxy-api --claude-login
|
||||
```
|
||||
|
||||
Tokens are saved to:
|
||||
- Gemini: `~/.cli-proxy-api/gemini-<email>.json`
|
||||
- OpenAI Codex: `~/.cli-proxy-api/codex-<email>.json`
|
||||
- Claude: `~/.cli-proxy-api/claude-<email>.json`
|
||||
|
||||
### 3. Start the Proxy
|
||||
|
||||
```bash
|
||||
./cli-proxy-api --config config.yaml
|
||||
```
|
||||
|
||||
Or run in background with tmux (recommended for remote servers):
|
||||
|
||||
```bash
|
||||
tmux new-session -d -s proxy "./cli-proxy-api --config config.yaml"
|
||||
```
|
||||
|
||||
### 4. Configure Amp CLI
|
||||
|
||||
#### Option A: Settings File
|
||||
|
||||
Edit `~/.config/amp/settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"amp.url": "http://localhost:8317"
|
||||
}
|
||||
```
|
||||
|
||||
#### Option B: Environment Variable
|
||||
|
||||
```bash
|
||||
export AMP_URL=http://localhost:8317
|
||||
```
|
||||
|
||||
### 5. Login and Use Amp
|
||||
|
||||
Login through the proxy (proxied to ampcode.com):
|
||||
|
||||
```bash
|
||||
amp login
|
||||
```
|
||||
|
||||
Use Amp as normal:
|
||||
|
||||
```bash
|
||||
amp "Write a hello world program in Python"
|
||||
```
|
||||
|
||||
### 6. (Optional) Configure Amp IDE Extension
|
||||
|
||||
The proxy also works with Amp IDE extensions for VS Code, Cursor, Windsurf, etc.
|
||||
|
||||
1. Open Amp extension settings in your IDE
|
||||
2. Set **Amp URL** to `http://localhost:8317`
|
||||
3. Login with your Amp account
|
||||
4. Start using Amp in your IDE
|
||||
|
||||
Both CLI and IDE can use the proxy simultaneously.
|
||||
|
||||
## Usage
|
||||
|
||||
### Supported Routes
|
||||
|
||||
#### Provider Aliases (Always Available)
|
||||
|
||||
These routes work even without `amp-upstream-url` configured:
|
||||
|
||||
- `/api/provider/openai/v1/chat/completions`
|
||||
- `/api/provider/openai/v1/responses`
|
||||
- `/api/provider/anthropic/v1/messages`
|
||||
- `/api/provider/google/v1beta/models/:action`
|
||||
|
||||
Amp CLI calls these routes with your OAuth-authenticated models configured in CLIProxyAPI.
|
||||
|
||||
#### Management Routes (Require `amp-upstream-url`)
|
||||
|
||||
These routes are proxied to ampcode.com:
|
||||
|
||||
- `/api/auth` - Authentication
|
||||
- `/api/user` - User profile
|
||||
- `/api/meta` - Metadata
|
||||
- `/api/threads` - Conversation threads
|
||||
- `/api/telemetry` - Usage telemetry
|
||||
- `/api/internal` - Internal APIs
|
||||
|
||||
**Security**: Restricted to localhost by default.
|
||||
|
||||
### Model Fallback Behavior
|
||||
|
||||
When Amp requests a model:
|
||||
|
||||
1. **Check local configuration**: Does CLIProxyAPI have OAuth tokens for this model's provider?
|
||||
2. **If YES**: Route to local handler (use your OAuth subscription)
|
||||
3. **If NO**: Check if a model mapping exists
|
||||
4. **If mapping exists**: Rewrite request to mapped model → Route to local handler (free)
|
||||
5. **If no mapping**: Forward to ampcode.com (uses Amp credits)
|
||||
|
||||
This enables seamless mixed usage:
|
||||
- Models you've configured (Gemini, ChatGPT, Claude) → Your OAuth subscriptions
|
||||
- Models with mappings configured → Routed to alternative local models (free)
|
||||
- Models you haven't configured and have no mapping → Amp's default providers (uses credits)
|
||||
|
||||
### Example API Calls
|
||||
|
||||
**Chat completion with local OAuth:**
|
||||
```bash
|
||||
curl http://localhost:8317/api/provider/openai/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-5",
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
}'
|
||||
```
|
||||
|
||||
**Management endpoint (localhost only):**
|
||||
```bash
|
||||
curl http://localhost:8317/api/user
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
| Symptom | Likely Cause | Fix |
|
||||
|---------|--------------|-----|
|
||||
| 404 on `/api/provider/...` | Incorrect route path | Ensure exact path: `/api/provider/{provider}/v1...` |
|
||||
| 403 on `/api/user` | Non-localhost request | Run from same machine or disable `amp-restrict-management-to-localhost` (not recommended) |
|
||||
| 401/403 from provider | Missing/expired OAuth | Re-run `--codex-login` or `--claude-login` |
|
||||
| Amp gzip errors | Response decompression issue | Update to latest build; auto-decompression should handle this |
|
||||
| Models not using proxy | Wrong Amp URL | Verify `amp.url` setting or `AMP_URL` environment variable |
|
||||
| CORS errors | Protected management endpoint | Use CLI/terminal, not browser |
|
||||
|
||||
### Diagnostics
|
||||
|
||||
**Check proxy logs:**
|
||||
```bash
|
||||
# If logging-to-file: true
|
||||
tail -f logs/requests.log
|
||||
|
||||
# If running in tmux
|
||||
tmux attach-session -t proxy
|
||||
```
|
||||
|
||||
**Enable debug mode** (temporarily):
|
||||
```yaml
|
||||
debug: true
|
||||
```
|
||||
|
||||
**Test basic connectivity:**
|
||||
```bash
|
||||
# Check if proxy is running
|
||||
curl http://localhost:8317/v1/models
|
||||
|
||||
# Check Amp-specific route
|
||||
curl http://localhost:8317/api/provider/openai/v1/models
|
||||
```
|
||||
|
||||
**Verify Amp configuration:**
|
||||
```bash
|
||||
# Check if Amp is using proxy
|
||||
amp config get amp.url
|
||||
|
||||
# Or check environment
|
||||
echo $AMP_URL
|
||||
```
|
||||
|
||||
### Security Checklist
|
||||
|
||||
- ✅ Keep `amp-restrict-management-to-localhost: true` (default)
|
||||
- ✅ Don't expose proxy publicly (bind to localhost or use firewall/VPN)
|
||||
- ✅ Use the Amp secrets file (`~/.local/share/amp/secrets.json`) managed by `amp login`
|
||||
- ✅ Rotate OAuth tokens periodically by re-running login commands
|
||||
- ✅ Store config and auth-dir on encrypted disk if handling sensitive data
|
||||
- ✅ Keep proxy binary up to date for security fixes
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [CLIProxyAPI Main Documentation](https://help.router-for.me/)
|
||||
- [Amp CLI Official Manual](https://ampcode.com/manual)
|
||||
- [Management API Reference](https://help.router-for.me/management/api)
|
||||
- [SDK Documentation](sdk-usage.md)
|
||||
|
||||
## Disclaimer
|
||||
|
||||
This integration is for personal/educational use. Using reverse proxies or alternate API bases may violate provider Terms of Service. You are solely responsible for how you use this software. Accounts may be rate-limited, locked, or banned. No warranties. Use at your own risk.
|
||||
@@ -1,392 +0,0 @@
|
||||
# Amp CLI 集成指南
|
||||
|
||||
本指南说明如何在 Amp CLI 和 Amp IDE 扩展中使用 CLIProxyAPI,通过 OAuth 让你能够把已有的 Google/ChatGPT/Claude 订阅与 Amp 的 CLI 一起使用。
|
||||
|
||||
## 目录
|
||||
|
||||
- [概述](#概述)
|
||||
- [应该认证哪些服务提供商?](#应该认证哪些服务提供商)
|
||||
- [架构](#架构)
|
||||
- [配置](#配置)
|
||||
- [设置](#设置)
|
||||
- [用法](#用法)
|
||||
- [故障排查](#故障排查)
|
||||
|
||||
## 概述
|
||||
|
||||
Amp CLI 集成为 Amp 的 API 模式添加了专用路由,同时保持与现有 CLIProxyAPI 功能的完全兼容。这样你可以在同一个代理服务器上同时使用传统 CLIProxyAPI 功能和 Amp CLI。
|
||||
|
||||
### 主要特性
|
||||
|
||||
- **提供者路由别名**:将 Amp 的 `/api/provider/{provider}/v1...` 路径映射到 CLIProxyAPI 处理器
|
||||
- **管理代理**:将 OAuth 和账号管理请求转发到 Amp 控制平面
|
||||
- **智能回退**:自动将未配置的模型路由到 ampcode.com
|
||||
- **密钥管理**:可配置优先级(配置 > 环境变量 > 文件),缓存 5 分钟
|
||||
- **安全优先**:管理路由默认限制为 localhost
|
||||
- **自动 gzip 处理**:自动解压来自 Amp 上游的响应
|
||||
|
||||
### 你可以做什么
|
||||
|
||||
- 使用 Amp CLI 搭配你的 Google 账号(Gemini 3 Pro Preview、Gemini 2.5 Pro、Gemini 2.5 Flash)
|
||||
- 使用 Amp CLI 搭配你的 ChatGPT Plus/Pro 订阅(GPT-5、GPT-5 Codex 模型)
|
||||
- 使用 Amp CLI 搭配你的 Claude Pro/Max 订阅(Claude Sonnet 4.5、Opus 4.1)
|
||||
- 将 Amp IDE 扩展(VS Code、Cursor、Windsurf 等)与同一个代理一起使用
|
||||
- 通过一个代理同时运行多个 CLI 工具(Factory + Amp)
|
||||
- 将未配置的模型自动路由到 ampcode.com
|
||||
|
||||
### 应该认证哪些服务提供商?
|
||||
|
||||
**重要**:需要认证的提供商取决于你安装的 Amp 版本当前使用的模型和功能。Amp 的不同智能模式和子代理会使用不同的提供商:
|
||||
|
||||
- **Smart 模式**:使用 Google/Gemini 模型(Gemini 3 Pro)
|
||||
- **Rush 模式**:使用 Anthropic/Claude 模型(Claude Haiku 4.5)
|
||||
- **Oracle 子代理**:使用 OpenAI/GPT 模型(GPT-5 medium reasoning)
|
||||
- **Librarian 子代理**:使用 Anthropic/Claude 模型(Claude Sonnet 4.5)
|
||||
- **Search 子代理**:使用 Anthropic/Claude 模型(Claude Haiku 4.5)
|
||||
- **Review 功能**:使用 Google/Gemini 模型(Gemini 2.5 Flash-Lite)
|
||||
|
||||
有关 Amp 当前使用哪些模型的最新信息,请参阅 **[Amp 模型文档](https://ampcode.com/models)**。
|
||||
|
||||
#### 回退行为
|
||||
|
||||
CLIProxyAPI 采用智能回退机制:
|
||||
|
||||
1. **本地已认证提供商**(`--login`、`--codex-login`、`--claude-login`):
|
||||
- 请求使用**你的 OAuth 订阅**(ChatGPT Plus/Pro、Claude Pro/Max、Google 账号)
|
||||
- 享受订阅自带的额度
|
||||
- 不消耗 Amp 额度
|
||||
|
||||
2. **本地未认证提供商**:
|
||||
- 请求自动转发到 **ampcode.com**
|
||||
- 使用 Amp 的后端提供商连接
|
||||
- 如果提供商是付费的(OpenAI、Anthropic 付费档),**需要消耗 Amp 额度**
|
||||
- 若 Amp 额度不足,可能产生错误
|
||||
|
||||
**建议**:对你有订阅的所有提供商都进行认证,以最大化价值并尽量减少 Amp 额度消耗。如果没有覆盖 Amp 使用的全部提供商,请确保为回退请求准备足够的 Amp 额度。
|
||||
|
||||
## 架构
|
||||
|
||||
### 请求流
|
||||
|
||||
```
|
||||
Amp CLI/IDE
|
||||
↓
|
||||
├─ Provider API requests (/api/provider/{provider}/v1/...)
|
||||
│ ↓
|
||||
│ ├─ Model configured locally?
|
||||
│ │ YES → Use local OAuth tokens (OpenAI/Claude/Gemini handlers)
|
||||
│ │ NO → Forward to ampcode.com (reverse proxy)
|
||||
│ ↓
|
||||
│ Response
|
||||
│
|
||||
└─ Management requests (/api/auth, /api/user, /api/threads, ...)
|
||||
↓
|
||||
├─ Localhost check (security)
|
||||
↓
|
||||
└─ Reverse proxy to ampcode.com
|
||||
↓
|
||||
Response (auto-decompressed if gzipped)
|
||||
```
|
||||
|
||||
### 组件
|
||||
|
||||
Amp 集成以模块化路由模块(`internal/api/modules/amp/`)实现,包含以下组件:
|
||||
|
||||
1. **路由别名**(`routes.go`):将 Amp 风格的路径映射到标准处理器
|
||||
2. **反向代理**(`proxy.go`):将管理请求转发到 ampcode.com
|
||||
3. **回退处理器**(`fallback_handlers.go`):将未配置的模型路由到 ampcode.com
|
||||
4. **密钥管理**(`secret.go`):多来源 API 密钥解析并带缓存
|
||||
5. **主模块**(`amp.go`):负责注册和配置
|
||||
|
||||
## 配置
|
||||
|
||||
### 基础配置
|
||||
|
||||
在 `config.yaml` 中新增以下字段:
|
||||
|
||||
```yaml
|
||||
# Amp 上游控制平面(管理路由必需)
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
|
||||
# 可选:覆盖 API key(否则使用环境变量或文件)
|
||||
# amp-upstream-api-key: "your-amp-api-key"
|
||||
|
||||
# 安全性:将管理路由限制为 localhost(推荐)
|
||||
amp-restrict-management-to-localhost: true
|
||||
```
|
||||
|
||||
### 密钥解析优先级
|
||||
|
||||
Amp 模块以如下优先级解析 API key:
|
||||
|
||||
| 来源 | 键名 | 优先级 | 缓存 |
|
||||
|------|------|--------|------|
|
||||
| 配置文件 | `amp-upstream-api-key` | 高 | 无 |
|
||||
| 环境变量 | `AMP_API_KEY` | 中 | 无 |
|
||||
| Amp 密钥文件 | `~/.local/share/amp/secrets.json` | 低 | 5 分钟 |
|
||||
|
||||
**建议**:日常使用时采用 Amp 密钥文件(最低优先级)。该文件由 `amp login` 自动管理。
|
||||
|
||||
### 安全设置
|
||||
|
||||
**`amp-restrict-management-to-localhost`**(默认:`true`)
|
||||
|
||||
启用后,管理路由(`/api/auth`、`/api/user`、`/api/threads` 等)只接受来自 localhost(127.0.0.1、::1)的连接,可防止:
|
||||
- 浏览器探测式攻击
|
||||
- 对管理端点的远程访问
|
||||
- 基于 CORS 的攻击
|
||||
- 伪造头攻击(例如 `X-Forwarded-For: 127.0.0.1`)
|
||||
|
||||
#### 工作原理
|
||||
|
||||
此限制使用**实际的 TCP 连接地址**(`RemoteAddr`),而非 `X-Forwarded-For` 等 HTTP 头,能防止头部伪造,但有重要影响:
|
||||
|
||||
- ✅ **直接连接可用**:在本机或服务器直接运行 CLIProxyAPI 时适用
|
||||
- ⚠️ **可能不适用于反向代理场景**:部署在 nginx、Cloudflare 等代理后,请求源会显示为代理 IP 而非 localhost
|
||||
|
||||
#### 反向代理部署
|
||||
|
||||
若需要在反向代理(nginx、Caddy、Cloudflare Tunnel 等)后运行 CLIProxyAPI:
|
||||
|
||||
1. **关闭 localhost 限制**:
|
||||
```yaml
|
||||
amp-restrict-management-to-localhost: false
|
||||
```
|
||||
|
||||
2. **使用替代安全措施**:
|
||||
- 防火墙规则限制管理路由访问
|
||||
- 代理层认证(HTTP Basic Auth、OAuth)
|
||||
- 网络隔离(VPN、Tailscale、Cloudflare Access)
|
||||
- 将 CLIProxyAPI 仅绑定 `127.0.0.1`,并通过 SSH 隧道访问
|
||||
|
||||
3. **nginx 示例配置**(阻止外部访问管理路由):
|
||||
```nginx
|
||||
location /api/auth { deny all; }
|
||||
location /api/user { deny all; }
|
||||
location /api/threads { deny all; }
|
||||
location /api/internal { deny all; }
|
||||
```
|
||||
|
||||
**重要**:只有在理解安全影响并已采取其他防护措施时,才关闭 `amp-restrict-management-to-localhost`。
|
||||
|
||||
## 设置
|
||||
|
||||
### 1. 配置 CLIProxyAPI
|
||||
|
||||
创建或编辑 `config.yaml`:
|
||||
|
||||
```yaml
|
||||
port: 8317
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
|
||||
# Amp 集成
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
amp-restrict-management-to-localhost: true
|
||||
|
||||
# 其他常规设置...
|
||||
debug: false
|
||||
logging-to-file: true
|
||||
```
|
||||
|
||||
### 2. 认证提供商
|
||||
|
||||
为要使用的提供商执行 OAuth 登录:
|
||||
|
||||
**Google 账号(Gemini 2.5 Pro、Gemini 2.5 Flash、Gemini 3 Pro Preview):**
|
||||
```bash
|
||||
./cli-proxy-api --login
|
||||
```
|
||||
|
||||
**ChatGPT Plus/Pro(GPT-5、GPT-5 Codex):**
|
||||
```bash
|
||||
./cli-proxy-api --codex-login
|
||||
```
|
||||
|
||||
**Claude Pro/Max(Claude Sonnet 4.5、Opus 4.1):**
|
||||
```bash
|
||||
./cli-proxy-api --claude-login
|
||||
```
|
||||
|
||||
令牌会保存到:
|
||||
- Gemini: `~/.cli-proxy-api/gemini-<email>.json`
|
||||
- OpenAI Codex: `~/.cli-proxy-api/codex-<email>.json`
|
||||
- Claude: `~/.cli-proxy-api/claude-<email>.json`
|
||||
|
||||
### 3. 启动代理
|
||||
|
||||
```bash
|
||||
./cli-proxy-api --config config.yaml
|
||||
```
|
||||
|
||||
或使用 tmux 在后台运行(推荐用于远程服务器):
|
||||
|
||||
```bash
|
||||
tmux new-session -d -s proxy "./cli-proxy-api --config config.yaml"
|
||||
```
|
||||
|
||||
### 4. 配置 Amp CLI
|
||||
|
||||
#### 方案 A:配置文件
|
||||
|
||||
编辑 `~/.config/amp/settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"amp.url": "http://localhost:8317"
|
||||
}
|
||||
```
|
||||
|
||||
#### 方案 B:环境变量
|
||||
|
||||
```bash
|
||||
export AMP_URL=http://localhost:8317
|
||||
```
|
||||
|
||||
### 5. 登录并使用 Amp
|
||||
|
||||
通过代理登录(请求会被代理到 ampcode.com):
|
||||
|
||||
```bash
|
||||
amp login
|
||||
```
|
||||
|
||||
像平常一样使用 Amp:
|
||||
|
||||
```bash
|
||||
amp "Write a hello world program in Python"
|
||||
```
|
||||
|
||||
### 6. (可选)配置 Amp IDE 扩展
|
||||
|
||||
该代理同样适用于 VS Code、Cursor、Windsurf 等 Amp IDE 扩展。
|
||||
|
||||
1. 在 IDE 中打开 Amp 扩展设置
|
||||
2. 将 **Amp URL** 设置为 `http://localhost:8317`
|
||||
3. 用你的 Amp 账号登录
|
||||
4. 在 IDE 中开始使用 Amp
|
||||
|
||||
CLI 和 IDE 可同时使用该代理。
|
||||
|
||||
## 用法
|
||||
|
||||
### 支持的路由
|
||||
|
||||
#### 提供商别名(始终可用)
|
||||
|
||||
这些路由即使未配置 `amp-upstream-url` 也可使用:
|
||||
|
||||
- `/api/provider/openai/v1/chat/completions`
|
||||
- `/api/provider/openai/v1/responses`
|
||||
- `/api/provider/anthropic/v1/messages`
|
||||
- `/api/provider/google/v1beta/models/:action`
|
||||
|
||||
Amp CLI 会使用你在 CLIProxyAPI 中通过 OAuth 认证的模型来调用这些路由。
|
||||
|
||||
#### 管理路由(需要 `amp-upstream-url`)
|
||||
|
||||
这些路由会被代理到 ampcode.com:
|
||||
|
||||
- `/api/auth` - 认证
|
||||
- `/api/user` - 用户资料
|
||||
- `/api/meta` - 元数据
|
||||
- `/api/threads` - 会话线程
|
||||
- `/api/telemetry` - 使用遥测
|
||||
- `/api/internal` - 内部 API
|
||||
|
||||
**安全性**:默认限制为 localhost。
|
||||
|
||||
### 模型回退行为
|
||||
|
||||
当 Amp 请求模型时:
|
||||
|
||||
1. **检查本地配置**:CLIProxyAPI 是否有该模型提供商的 OAuth 令牌?
|
||||
2. **如果有**:路由到本地处理器(使用你的 OAuth 订阅)
|
||||
3. **如果没有**:转发到 ampcode.com(使用 Amp 的默认路由)
|
||||
|
||||
这实现了无缝混用:
|
||||
- 你已配置的模型(Gemini、ChatGPT、Claude)→ 你的 OAuth 订阅
|
||||
- 未配置的模型 → Amp 的默认提供商
|
||||
|
||||
### 示例 API 调用
|
||||
|
||||
**使用本地 OAuth 的聊天补全:**
|
||||
```bash
|
||||
curl http://localhost:8317/api/provider/openai/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-5",
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
}'
|
||||
```
|
||||
|
||||
**管理端点(仅限 localhost):**
|
||||
```bash
|
||||
curl http://localhost:8317/api/user
|
||||
```
|
||||
|
||||
## 故障排查
|
||||
|
||||
### 常见问题
|
||||
|
||||
| 症状 | 可能原因 | 解决方案 |
|
||||
|------|----------|----------|
|
||||
| `/api/provider/...` 返回 404 | 路径错误 | 确保路径准确:`/api/provider/{provider}/v1...` |
|
||||
| `/api/user` 返回 403 | 非 localhost 请求 | 在同一机器上访问,或关闭 `amp-restrict-management-to-localhost`(不推荐) |
|
||||
| 提供商返回 401/403 | OAuth 缺失或过期 | 重新运行 `--codex-login` 或 `--claude-login` |
|
||||
| Amp gzip 错误 | 响应解压问题 | 更新到最新构建;自动解压应能处理 |
|
||||
| 模型未走代理 | Amp URL 设置错误 | 检查 `amp.url` 设置或 `AMP_URL` 环境变量 |
|
||||
| CORS 错误 | 受保护的管理端点 | 使用 CLI/终端而非浏览器 |
|
||||
|
||||
### 诊断
|
||||
|
||||
**查看代理日志:**
|
||||
```bash
|
||||
# 若 logging-to-file: true
|
||||
tail -f logs/requests.log
|
||||
|
||||
# 若运行在 tmux 中
|
||||
tmux attach-session -t proxy
|
||||
```
|
||||
|
||||
**临时开启调试模式:**
|
||||
```yaml
|
||||
debug: true
|
||||
```
|
||||
|
||||
**测试基础连通性:**
|
||||
```bash
|
||||
# 检查代理是否运行
|
||||
curl http://localhost:8317/v1/models
|
||||
|
||||
# 检查 Amp 特定路由
|
||||
curl http://localhost:8317/api/provider/openai/v1/models
|
||||
```
|
||||
|
||||
**验证 Amp 配置:**
|
||||
```bash
|
||||
# 检查 Amp 是否使用代理
|
||||
amp config get amp.url
|
||||
|
||||
# 或检查环境变量
|
||||
echo $AMP_URL
|
||||
```
|
||||
|
||||
### 安全清单
|
||||
|
||||
- ✅ 保持 `amp-restrict-management-to-localhost: true`(默认)
|
||||
- ✅ 不要将代理暴露到公共网络(绑定到 localhost 或使用防火墙/VPN)
|
||||
- ✅ 使用 `amp login` 管理的 Amp 密钥文件(`~/.local/share/amp/secrets.json`)
|
||||
- ✅ 定期重新登录轮换 OAuth 令牌
|
||||
- ✅ 若处理敏感数据,使用加密磁盘存储配置和 auth-dir
|
||||
- ✅ 保持代理二进制为最新版本以获取安全修复
|
||||
|
||||
## 其他资源
|
||||
|
||||
- [CLIProxyAPI 主文档](https://help.router-for.me/)
|
||||
- [Amp CLI 官方手册](https://ampcode.com/manual)
|
||||
- [管理 API 参考](https://help.router-for.me/management/api)
|
||||
- [SDK 文档](sdk-usage.md)
|
||||
|
||||
## 免责声明
|
||||
|
||||
此集成仅用于个人或教育用途。使用反向代理或替代 API 基址可能违反提供商的服务条款。你需要对自己的使用方式负责。账号可能会被限速、锁定或封禁。软件不附带任何保证,使用风险自负。
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -23,13 +24,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
clipexec "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/logging"
|
||||
sdktr "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
)
|
||||
|
||||
@@ -122,7 +123,9 @@ func (MyExecutor) Execute(ctx context.Context, a *coreauth.Auth, req clipexec.Re
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Inject credentials via PrepareRequest hook.
|
||||
_ = (MyExecutor{}).PrepareRequest(httpReq, a)
|
||||
if errPrep := (MyExecutor{}).PrepareRequest(httpReq, a); errPrep != nil {
|
||||
return clipexec.Response{}, errPrep
|
||||
}
|
||||
|
||||
resp, errDo := client.Do(httpReq)
|
||||
if errDo != nil {
|
||||
@@ -130,13 +133,28 @@ func (MyExecutor) Execute(ctx context.Context, a *coreauth.Auth, req clipexec.Re
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
// Best-effort close; log if needed in real projects.
|
||||
fmt.Fprintf(os.Stderr, "close response body error: %v\n", errClose)
|
||||
}
|
||||
}()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return clipexec.Response{Payload: body}, nil
|
||||
}
|
||||
|
||||
func (MyExecutor) HttpRequest(ctx context.Context, a *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("myprov executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if errPrep := (MyExecutor{}).PrepareRequest(httpReq, a); errPrep != nil {
|
||||
return nil, errPrep
|
||||
}
|
||||
client := buildHTTPClient(a)
|
||||
return client.Do(httpReq)
|
||||
}
|
||||
|
||||
func (MyExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (clipexec.Response, error) {
|
||||
return clipexec.Response{}, errors.New("count tokens not implemented")
|
||||
}
|
||||
@@ -199,8 +217,8 @@ func main() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if err := svc.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
|
||||
panic(err)
|
||||
if errRun := svc.Run(ctx); errRun != nil && !errors.Is(errRun, context.Canceled) {
|
||||
panic(errRun)
|
||||
}
|
||||
_ = os.Stderr // keep os import used (demo only)
|
||||
_ = time.Second
|
||||
|
||||
140
examples/http-request/main.go
Normal file
140
examples/http-request/main.go
Normal file
@@ -0,0 +1,140 @@
|
||||
// Package main demonstrates how to use coreauth.Manager.HttpRequest/NewHttpRequest
|
||||
// to execute arbitrary HTTP requests with provider credentials injected.
|
||||
//
|
||||
// This example registers a minimal custom executor that injects an Authorization
|
||||
// header from auth.Attributes["api_key"], then performs two requests against
|
||||
// httpbin.org to show the injected headers.
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
clipexec "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const providerKey = "echo"
|
||||
|
||||
// EchoExecutor is a minimal provider implementation for demonstration purposes.
|
||||
type EchoExecutor struct{}
|
||||
|
||||
func (EchoExecutor) Identifier() string { return providerKey }
|
||||
|
||||
func (EchoExecutor) PrepareRequest(req *http.Request, auth *coreauth.Auth) error {
|
||||
if req == nil || auth == nil {
|
||||
return nil
|
||||
}
|
||||
if auth.Attributes != nil {
|
||||
if apiKey := strings.TrimSpace(auth.Attributes["api_key"]); apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (EchoExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("echo executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if errPrep := (EchoExecutor{}).PrepareRequest(httpReq, auth); errPrep != nil {
|
||||
return nil, errPrep
|
||||
}
|
||||
return http.DefaultClient.Do(httpReq)
|
||||
}
|
||||
|
||||
func (EchoExecutor) Execute(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (clipexec.Response, error) {
|
||||
return clipexec.Response{}, errors.New("echo executor: Execute not implemented")
|
||||
}
|
||||
|
||||
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (<-chan clipexec.StreamChunk, error) {
|
||||
return nil, errors.New("echo executor: ExecuteStream not implemented")
|
||||
}
|
||||
|
||||
func (EchoExecutor) Refresh(context.Context, *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return nil, errors.New("echo executor: Refresh not implemented")
|
||||
}
|
||||
|
||||
func (EchoExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (clipexec.Response, error) {
|
||||
return clipexec.Response{}, errors.New("echo executor: CountTokens not implemented")
|
||||
}
|
||||
|
||||
func main() {
|
||||
log.SetLevel(log.InfoLevel)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
core := coreauth.NewManager(nil, nil, nil)
|
||||
core.RegisterExecutor(EchoExecutor{})
|
||||
|
||||
auth := &coreauth.Auth{
|
||||
ID: "demo-echo",
|
||||
Provider: providerKey,
|
||||
Attributes: map[string]string{
|
||||
"api_key": "demo-api-key",
|
||||
},
|
||||
}
|
||||
|
||||
// Example 1: Build a prepared request and execute it using your own http.Client.
|
||||
reqPrepared, errReqPrepared := core.NewHttpRequest(
|
||||
ctx,
|
||||
auth,
|
||||
http.MethodGet,
|
||||
"https://httpbin.org/anything",
|
||||
nil,
|
||||
http.Header{"X-Example": []string{"prepared"}},
|
||||
)
|
||||
if errReqPrepared != nil {
|
||||
panic(errReqPrepared)
|
||||
}
|
||||
respPrepared, errDoPrepared := http.DefaultClient.Do(reqPrepared)
|
||||
if errDoPrepared != nil {
|
||||
panic(errDoPrepared)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := respPrepared.Body.Close(); errClose != nil {
|
||||
log.Errorf("close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
bodyPrepared, errReadPrepared := io.ReadAll(respPrepared.Body)
|
||||
if errReadPrepared != nil {
|
||||
panic(errReadPrepared)
|
||||
}
|
||||
fmt.Printf("Prepared request status: %d\n%s\n\n", respPrepared.StatusCode, bodyPrepared)
|
||||
|
||||
// Example 2: Execute a raw request via core.HttpRequest (auto inject + do).
|
||||
rawBody := []byte(`{"hello":"world"}`)
|
||||
rawReq, errRawReq := http.NewRequestWithContext(ctx, http.MethodPost, "https://httpbin.org/anything", bytes.NewReader(rawBody))
|
||||
if errRawReq != nil {
|
||||
panic(errRawReq)
|
||||
}
|
||||
rawReq.Header.Set("Content-Type", "application/json")
|
||||
rawReq.Header.Set("X-Example", "executed")
|
||||
|
||||
respExec, errDoExec := core.HttpRequest(ctx, auth, rawReq)
|
||||
if errDoExec != nil {
|
||||
panic(errDoExec)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := respExec.Body.Close(); errClose != nil {
|
||||
log.Errorf("close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
bodyExec, errReadExec := io.ReadAll(respExec.Body)
|
||||
if errReadExec != nil {
|
||||
panic(errReadExec)
|
||||
}
|
||||
fmt.Printf("Manager HttpRequest status: %d\n%s\n", respExec.StatusCode, bodyExec)
|
||||
}
|
||||
12
go.mod
12
go.mod
@@ -18,10 +18,11 @@ require (
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/tiktoken-go/tokenizer v0.7.0
|
||||
golang.org/x/crypto v0.43.0
|
||||
golang.org/x/net v0.46.0
|
||||
golang.org/x/crypto v0.45.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/term v0.36.0
|
||||
golang.org/x/sync v0.18.0
|
||||
golang.org/x/term v0.37.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
@@ -69,9 +70,8 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/sync v0.17.0 // indirect
|
||||
golang.org/x/sys v0.37.0 // indirect
|
||||
golang.org/x/text v0.30.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
)
|
||||
|
||||
24
go.sum
24
go.sum
@@ -160,23 +160,23 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
|
||||
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
|
||||
704
internal/api/handlers/management/api_tools.go
Normal file
704
internal/api/handlers/management/api_tools.go
Normal file
@@ -0,0 +1,704 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/proxy"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
const defaultAPICallTimeout = 60 * time.Second
|
||||
|
||||
const (
|
||||
geminiOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
geminiOAuthClientSecret = "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"
|
||||
)
|
||||
|
||||
var geminiOAuthScopes = []string{
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
}
|
||||
|
||||
const (
|
||||
antigravityOAuthClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
antigravityOAuthClientSecret = "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"
|
||||
)
|
||||
|
||||
var antigravityOAuthTokenURL = "https://oauth2.googleapis.com/token"
|
||||
|
||||
type apiCallRequest struct {
|
||||
AuthIndexSnake *string `json:"auth_index"`
|
||||
AuthIndexCamel *string `json:"authIndex"`
|
||||
AuthIndexPascal *string `json:"AuthIndex"`
|
||||
Method string `json:"method"`
|
||||
URL string `json:"url"`
|
||||
Header map[string]string `json:"header"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type apiCallResponse struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
Header map[string][]string `json:"header"`
|
||||
Body string `json:"body"`
|
||||
}
|
||||
|
||||
// APICall makes a generic HTTP request on behalf of the management API caller.
|
||||
// It is protected by the management middleware.
|
||||
//
|
||||
// Endpoint:
|
||||
//
|
||||
// POST /v0/management/api-call
|
||||
//
|
||||
// Authentication:
|
||||
//
|
||||
// Same as other management APIs (requires a management key and remote-management rules).
|
||||
// You can provide the key via:
|
||||
// - Authorization: Bearer <key>
|
||||
// - X-Management-Key: <key>
|
||||
//
|
||||
// Request JSON:
|
||||
// - auth_index / authIndex / AuthIndex (optional):
|
||||
// The credential "auth_index" from GET /v0/management/auth-files (or other endpoints returning it).
|
||||
// If omitted or not found, credential-specific proxy/token substitution is skipped.
|
||||
// - method (required): HTTP method, e.g. GET, POST, PUT, PATCH, DELETE.
|
||||
// - url (required): Absolute URL including scheme and host, e.g. "https://api.example.com/v1/ping".
|
||||
// - header (optional): Request headers map.
|
||||
// Supports magic variable "$TOKEN$" which is replaced using the selected credential:
|
||||
// 1) metadata.access_token
|
||||
// 2) attributes.api_key
|
||||
// 3) metadata.token / metadata.id_token / metadata.cookie
|
||||
// Example: {"Authorization":"Bearer $TOKEN$"}.
|
||||
// Note: if you need to override the HTTP Host header, set header["Host"].
|
||||
// - data (optional): Raw request body as string (useful for POST/PUT/PATCH).
|
||||
//
|
||||
// Proxy selection (highest priority first):
|
||||
// 1. Selected credential proxy_url
|
||||
// 2. Global config proxy-url
|
||||
// 3. Direct connect (environment proxies are not used)
|
||||
//
|
||||
// Response JSON (returned with HTTP 200 when the APICall itself succeeds):
|
||||
// - status_code: Upstream HTTP status code.
|
||||
// - header: Upstream response headers.
|
||||
// - body: Upstream response body as string.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
|
||||
// -H "Authorization: Bearer <MANAGEMENT_KEY>" \
|
||||
// -H "Content-Type: application/json" \
|
||||
// -d '{"auth_index":"<AUTH_INDEX>","method":"GET","url":"https://api.example.com/v1/ping","header":{"Authorization":"Bearer $TOKEN$"}}'
|
||||
//
|
||||
// curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
|
||||
// -H "Authorization: Bearer 831227" \
|
||||
// -H "Content-Type: application/json" \
|
||||
// -d '{"auth_index":"<AUTH_INDEX>","method":"POST","url":"https://api.example.com/v1/fetchAvailableModels","header":{"Authorization":"Bearer $TOKEN$","Content-Type":"application/json","User-Agent":"cliproxyapi"},"data":"{}"}'
|
||||
func (h *Handler) APICall(c *gin.Context) {
|
||||
var body apiCallRequest
|
||||
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
|
||||
method := strings.ToUpper(strings.TrimSpace(body.Method))
|
||||
if method == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing method"})
|
||||
return
|
||||
}
|
||||
|
||||
urlStr := strings.TrimSpace(body.URL)
|
||||
if urlStr == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing url"})
|
||||
return
|
||||
}
|
||||
parsedURL, errParseURL := url.Parse(urlStr)
|
||||
if errParseURL != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"})
|
||||
return
|
||||
}
|
||||
|
||||
authIndex := firstNonEmptyString(body.AuthIndexSnake, body.AuthIndexCamel, body.AuthIndexPascal)
|
||||
auth := h.authByIndex(authIndex)
|
||||
|
||||
reqHeaders := body.Header
|
||||
if reqHeaders == nil {
|
||||
reqHeaders = map[string]string{}
|
||||
}
|
||||
|
||||
var hostOverride string
|
||||
var token string
|
||||
var tokenResolved bool
|
||||
var tokenErr error
|
||||
for key, value := range reqHeaders {
|
||||
if !strings.Contains(value, "$TOKEN$") {
|
||||
continue
|
||||
}
|
||||
if !tokenResolved {
|
||||
token, tokenErr = h.resolveTokenForAuth(c.Request.Context(), auth)
|
||||
tokenResolved = true
|
||||
}
|
||||
if auth != nil && token == "" {
|
||||
if tokenErr != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "auth token refresh failed"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "auth token not found"})
|
||||
return
|
||||
}
|
||||
if token == "" {
|
||||
continue
|
||||
}
|
||||
reqHeaders[key] = strings.ReplaceAll(value, "$TOKEN$", token)
|
||||
}
|
||||
|
||||
var requestBody io.Reader
|
||||
if body.Data != "" {
|
||||
requestBody = strings.NewReader(body.Data)
|
||||
}
|
||||
|
||||
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
|
||||
if errNewRequest != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to build request"})
|
||||
return
|
||||
}
|
||||
|
||||
for key, value := range reqHeaders {
|
||||
if strings.EqualFold(key, "host") {
|
||||
hostOverride = strings.TrimSpace(value)
|
||||
continue
|
||||
}
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
if hostOverride != "" {
|
||||
req.Host = hostOverride
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: defaultAPICallTimeout,
|
||||
}
|
||||
httpClient.Transport = h.apiCallTransport(auth)
|
||||
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
log.WithError(errDo).Debug("management APICall request failed")
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "request failed"})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
respBody, errReadAll := io.ReadAll(resp.Body)
|
||||
if errReadAll != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "failed to read response"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, apiCallResponse{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header,
|
||||
Body: string(respBody),
|
||||
})
|
||||
}
|
||||
|
||||
func firstNonEmptyString(values ...*string) string {
|
||||
for _, v := range values {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
if out := strings.TrimSpace(*v); out != "" {
|
||||
return out
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func tokenValueForAuth(auth *coreauth.Auth) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
}
|
||||
if v := tokenValueFromMetadata(auth.Metadata); v != "" {
|
||||
return v
|
||||
}
|
||||
if auth.Attributes != nil {
|
||||
if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
|
||||
if v := tokenValueFromMetadata(shared.MetadataSnapshot()); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (h *Handler) resolveTokenForAuth(ctx context.Context, auth *coreauth.Auth) (string, error) {
|
||||
if auth == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
if provider == "gemini-cli" {
|
||||
token, errToken := h.refreshGeminiOAuthAccessToken(ctx, auth)
|
||||
return token, errToken
|
||||
}
|
||||
if provider == "antigravity" {
|
||||
token, errToken := h.refreshAntigravityOAuthAccessToken(ctx, auth)
|
||||
return token, errToken
|
||||
}
|
||||
|
||||
return tokenValueForAuth(auth), nil
|
||||
}
|
||||
|
||||
func (h *Handler) refreshGeminiOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if auth == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
metadata, updater := geminiOAuthMetadata(auth)
|
||||
if len(metadata) == 0 {
|
||||
return "", fmt.Errorf("gemini oauth metadata missing")
|
||||
}
|
||||
|
||||
base := make(map[string]any)
|
||||
if tokenRaw, ok := metadata["token"].(map[string]any); ok && tokenRaw != nil {
|
||||
base = cloneMap(tokenRaw)
|
||||
}
|
||||
|
||||
var token oauth2.Token
|
||||
if len(base) > 0 {
|
||||
if raw, errMarshal := json.Marshal(base); errMarshal == nil {
|
||||
_ = json.Unmarshal(raw, &token)
|
||||
}
|
||||
}
|
||||
|
||||
if token.AccessToken == "" {
|
||||
token.AccessToken = stringValue(metadata, "access_token")
|
||||
}
|
||||
if token.RefreshToken == "" {
|
||||
token.RefreshToken = stringValue(metadata, "refresh_token")
|
||||
}
|
||||
if token.TokenType == "" {
|
||||
token.TokenType = stringValue(metadata, "token_type")
|
||||
}
|
||||
if token.Expiry.IsZero() {
|
||||
if expiry := stringValue(metadata, "expiry"); expiry != "" {
|
||||
if ts, errParseTime := time.Parse(time.RFC3339, expiry); errParseTime == nil {
|
||||
token.Expiry = ts
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
conf := &oauth2.Config{
|
||||
ClientID: geminiOAuthClientID,
|
||||
ClientSecret: geminiOAuthClientSecret,
|
||||
Scopes: geminiOAuthScopes,
|
||||
Endpoint: google.Endpoint,
|
||||
}
|
||||
|
||||
ctxToken := ctx
|
||||
httpClient := &http.Client{
|
||||
Timeout: defaultAPICallTimeout,
|
||||
Transport: h.apiCallTransport(auth),
|
||||
}
|
||||
ctxToken = context.WithValue(ctxToken, oauth2.HTTPClient, httpClient)
|
||||
|
||||
src := conf.TokenSource(ctxToken, &token)
|
||||
currentToken, errToken := src.Token()
|
||||
if errToken != nil {
|
||||
return "", errToken
|
||||
}
|
||||
|
||||
merged := buildOAuthTokenMap(base, currentToken)
|
||||
fields := buildOAuthTokenFields(currentToken, merged)
|
||||
if updater != nil {
|
||||
updater(fields)
|
||||
}
|
||||
return strings.TrimSpace(currentToken.AccessToken), nil
|
||||
}
|
||||
|
||||
func (h *Handler) refreshAntigravityOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if auth == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
metadata := auth.Metadata
|
||||
if len(metadata) == 0 {
|
||||
return "", fmt.Errorf("antigravity oauth metadata missing")
|
||||
}
|
||||
|
||||
current := strings.TrimSpace(tokenValueFromMetadata(metadata))
|
||||
if current != "" && !antigravityTokenNeedsRefresh(metadata) {
|
||||
return current, nil
|
||||
}
|
||||
|
||||
refreshToken := stringValue(metadata, "refresh_token")
|
||||
if refreshToken == "" {
|
||||
return "", fmt.Errorf("antigravity refresh token missing")
|
||||
}
|
||||
|
||||
tokenURL := strings.TrimSpace(antigravityOAuthTokenURL)
|
||||
if tokenURL == "" {
|
||||
tokenURL = "https://oauth2.googleapis.com/token"
|
||||
}
|
||||
form := url.Values{}
|
||||
form.Set("client_id", antigravityOAuthClientID)
|
||||
form.Set("client_secret", antigravityOAuthClientSecret)
|
||||
form.Set("grant_type", "refresh_token")
|
||||
form.Set("refresh_token", refreshToken)
|
||||
|
||||
req, errReq := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
|
||||
if errReq != nil {
|
||||
return "", errReq
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: defaultAPICallTimeout,
|
||||
Transport: h.apiCallTransport(auth),
|
||||
}
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
return "", errDo
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(resp.Body)
|
||||
if errRead != nil {
|
||||
return "", errRead
|
||||
}
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
return "", fmt.Errorf("antigravity oauth token refresh failed: status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
if errUnmarshal := json.Unmarshal(bodyBytes, &tokenResp); errUnmarshal != nil {
|
||||
return "", errUnmarshal
|
||||
}
|
||||
|
||||
if strings.TrimSpace(tokenResp.AccessToken) == "" {
|
||||
return "", fmt.Errorf("antigravity oauth token refresh returned empty access_token")
|
||||
}
|
||||
|
||||
if auth.Metadata == nil {
|
||||
auth.Metadata = make(map[string]any)
|
||||
}
|
||||
now := time.Now()
|
||||
auth.Metadata["access_token"] = strings.TrimSpace(tokenResp.AccessToken)
|
||||
if strings.TrimSpace(tokenResp.RefreshToken) != "" {
|
||||
auth.Metadata["refresh_token"] = strings.TrimSpace(tokenResp.RefreshToken)
|
||||
}
|
||||
if tokenResp.ExpiresIn > 0 {
|
||||
auth.Metadata["expires_in"] = tokenResp.ExpiresIn
|
||||
auth.Metadata["timestamp"] = now.UnixMilli()
|
||||
auth.Metadata["expired"] = now.Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339)
|
||||
}
|
||||
auth.Metadata["type"] = "antigravity"
|
||||
|
||||
if h != nil && h.authManager != nil {
|
||||
auth.LastRefreshedAt = now
|
||||
auth.UpdatedAt = now
|
||||
_, _ = h.authManager.Update(ctx, auth)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(tokenResp.AccessToken), nil
|
||||
}
|
||||
|
||||
func antigravityTokenNeedsRefresh(metadata map[string]any) bool {
|
||||
// Refresh a bit early to avoid requests racing token expiry.
|
||||
const skew = 30 * time.Second
|
||||
|
||||
if metadata == nil {
|
||||
return true
|
||||
}
|
||||
if expStr, ok := metadata["expired"].(string); ok {
|
||||
if ts, errParse := time.Parse(time.RFC3339, strings.TrimSpace(expStr)); errParse == nil {
|
||||
return !ts.After(time.Now().Add(skew))
|
||||
}
|
||||
}
|
||||
expiresIn := int64Value(metadata["expires_in"])
|
||||
timestampMs := int64Value(metadata["timestamp"])
|
||||
if expiresIn > 0 && timestampMs > 0 {
|
||||
exp := time.UnixMilli(timestampMs).Add(time.Duration(expiresIn) * time.Second)
|
||||
return !exp.After(time.Now().Add(skew))
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func int64Value(raw any) int64 {
|
||||
switch typed := raw.(type) {
|
||||
case int:
|
||||
return int64(typed)
|
||||
case int32:
|
||||
return int64(typed)
|
||||
case int64:
|
||||
return typed
|
||||
case uint:
|
||||
return int64(typed)
|
||||
case uint32:
|
||||
return int64(typed)
|
||||
case uint64:
|
||||
if typed > uint64(^uint64(0)>>1) {
|
||||
return 0
|
||||
}
|
||||
return int64(typed)
|
||||
case float32:
|
||||
return int64(typed)
|
||||
case float64:
|
||||
return int64(typed)
|
||||
case json.Number:
|
||||
if i, errParse := typed.Int64(); errParse == nil {
|
||||
return i
|
||||
}
|
||||
case string:
|
||||
if s := strings.TrimSpace(typed); s != "" {
|
||||
if i, errParse := json.Number(s).Int64(); errParse == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func geminiOAuthMetadata(auth *coreauth.Auth) (map[string]any, func(map[string]any)) {
|
||||
if auth == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
|
||||
snapshot := shared.MetadataSnapshot()
|
||||
return snapshot, func(fields map[string]any) { shared.MergeMetadata(fields) }
|
||||
}
|
||||
return auth.Metadata, func(fields map[string]any) {
|
||||
if auth.Metadata == nil {
|
||||
auth.Metadata = make(map[string]any)
|
||||
}
|
||||
for k, v := range fields {
|
||||
auth.Metadata[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stringValue(metadata map[string]any, key string) string {
|
||||
if len(metadata) == 0 || key == "" {
|
||||
return ""
|
||||
}
|
||||
if v, ok := metadata[key].(string); ok {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func cloneMap(in map[string]any) map[string]any {
|
||||
if len(in) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]any, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildOAuthTokenMap(base map[string]any, tok *oauth2.Token) map[string]any {
|
||||
merged := cloneMap(base)
|
||||
if merged == nil {
|
||||
merged = make(map[string]any)
|
||||
}
|
||||
if tok == nil {
|
||||
return merged
|
||||
}
|
||||
if raw, errMarshal := json.Marshal(tok); errMarshal == nil {
|
||||
var tokenMap map[string]any
|
||||
if errUnmarshal := json.Unmarshal(raw, &tokenMap); errUnmarshal == nil {
|
||||
for k, v := range tokenMap {
|
||||
merged[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func buildOAuthTokenFields(tok *oauth2.Token, merged map[string]any) map[string]any {
|
||||
fields := make(map[string]any, 5)
|
||||
if tok != nil && tok.AccessToken != "" {
|
||||
fields["access_token"] = tok.AccessToken
|
||||
}
|
||||
if tok != nil && tok.TokenType != "" {
|
||||
fields["token_type"] = tok.TokenType
|
||||
}
|
||||
if tok != nil && tok.RefreshToken != "" {
|
||||
fields["refresh_token"] = tok.RefreshToken
|
||||
}
|
||||
if tok != nil && !tok.Expiry.IsZero() {
|
||||
fields["expiry"] = tok.Expiry.Format(time.RFC3339)
|
||||
}
|
||||
if len(merged) > 0 {
|
||||
fields["token"] = cloneMap(merged)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
func tokenValueFromMetadata(metadata map[string]any) string {
|
||||
if len(metadata) == 0 {
|
||||
return ""
|
||||
}
|
||||
if v, ok := metadata["accessToken"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := metadata["access_token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
if tokenRaw, ok := metadata["token"]; ok && tokenRaw != nil {
|
||||
switch typed := tokenRaw.(type) {
|
||||
case string:
|
||||
if v := strings.TrimSpace(typed); v != "" {
|
||||
return v
|
||||
}
|
||||
case map[string]any:
|
||||
if v, ok := typed["access_token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := typed["accessToken"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
case map[string]string:
|
||||
if v := strings.TrimSpace(typed["access_token"]); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(typed["accessToken"]); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
}
|
||||
if v, ok := metadata["token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := metadata["id_token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := metadata["cookie"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (h *Handler) authByIndex(authIndex string) *coreauth.Auth {
|
||||
authIndex = strings.TrimSpace(authIndex)
|
||||
if authIndex == "" || h == nil || h.authManager == nil {
|
||||
return nil
|
||||
}
|
||||
auths := h.authManager.List()
|
||||
for _, auth := range auths {
|
||||
if auth == nil {
|
||||
continue
|
||||
}
|
||||
auth.EnsureIndex()
|
||||
if auth.Index == authIndex {
|
||||
return auth
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
|
||||
var proxyCandidates []string
|
||||
if auth != nil {
|
||||
if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" {
|
||||
proxyCandidates = append(proxyCandidates, proxyStr)
|
||||
}
|
||||
}
|
||||
if h != nil && h.cfg != nil {
|
||||
if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" {
|
||||
proxyCandidates = append(proxyCandidates, proxyStr)
|
||||
}
|
||||
}
|
||||
|
||||
for _, proxyStr := range proxyCandidates {
|
||||
if transport := buildProxyTransport(proxyStr); transport != nil {
|
||||
return transport
|
||||
}
|
||||
}
|
||||
|
||||
transport, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok || transport == nil {
|
||||
return &http.Transport{Proxy: nil}
|
||||
}
|
||||
clone := transport.Clone()
|
||||
clone.Proxy = nil
|
||||
return clone
|
||||
}
|
||||
|
||||
func buildProxyTransport(proxyStr string) *http.Transport {
|
||||
proxyStr = strings.TrimSpace(proxyStr)
|
||||
if proxyStr == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
proxyURL, errParse := url.Parse(proxyStr)
|
||||
if errParse != nil {
|
||||
log.WithError(errParse).Debug("parse proxy URL failed")
|
||||
return nil
|
||||
}
|
||||
if proxyURL.Scheme == "" || proxyURL.Host == "" {
|
||||
log.Debug("proxy URL missing scheme/host")
|
||||
return nil
|
||||
}
|
||||
|
||||
if proxyURL.Scheme == "socks5" {
|
||||
var proxyAuth *proxy.Auth
|
||||
if proxyURL.User != nil {
|
||||
username := proxyURL.User.Username()
|
||||
password, _ := proxyURL.User.Password()
|
||||
proxyAuth = &proxy.Auth{User: username, Password: password}
|
||||
}
|
||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
|
||||
if errSOCKS5 != nil {
|
||||
log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
|
||||
return nil
|
||||
}
|
||||
return &http.Transport{
|
||||
Proxy: nil,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
||||
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
||||
}
|
||||
|
||||
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
||||
return nil
|
||||
}
|
||||
173
internal/api/handlers/management/api_tools_test.go
Normal file
173
internal/api/handlers/management/api_tools_test.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
type memoryAuthStore struct {
|
||||
mu sync.Mutex
|
||||
items map[string]*coreauth.Auth
|
||||
}
|
||||
|
||||
func (s *memoryAuthStore) List(ctx context.Context) ([]*coreauth.Auth, error) {
|
||||
_ = ctx
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := make([]*coreauth.Auth, 0, len(s.items))
|
||||
for _, a := range s.items {
|
||||
out = append(out, a.Clone())
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *memoryAuthStore) Save(ctx context.Context, auth *coreauth.Auth) (string, error) {
|
||||
_ = ctx
|
||||
if auth == nil {
|
||||
return "", nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
if s.items == nil {
|
||||
s.items = make(map[string]*coreauth.Auth)
|
||||
}
|
||||
s.items[auth.ID] = auth.Clone()
|
||||
s.mu.Unlock()
|
||||
return auth.ID, nil
|
||||
}
|
||||
|
||||
func (s *memoryAuthStore) Delete(ctx context.Context, id string) error {
|
||||
_ = ctx
|
||||
s.mu.Lock()
|
||||
delete(s.items, id)
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestResolveTokenForAuth_Antigravity_RefreshesExpiredToken(t *testing.T) {
|
||||
var callCount int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") {
|
||||
t.Fatalf("unexpected content-type: %s", ct)
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
_ = r.Body.Close()
|
||||
values, err := url.ParseQuery(string(bodyBytes))
|
||||
if err != nil {
|
||||
t.Fatalf("parse form: %v", err)
|
||||
}
|
||||
if values.Get("grant_type") != "refresh_token" {
|
||||
t.Fatalf("unexpected grant_type: %s", values.Get("grant_type"))
|
||||
}
|
||||
if values.Get("refresh_token") != "rt" {
|
||||
t.Fatalf("unexpected refresh_token: %s", values.Get("refresh_token"))
|
||||
}
|
||||
if values.Get("client_id") != antigravityOAuthClientID {
|
||||
t.Fatalf("unexpected client_id: %s", values.Get("client_id"))
|
||||
}
|
||||
if values.Get("client_secret") != antigravityOAuthClientSecret {
|
||||
t.Fatalf("unexpected client_secret")
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"access_token": "new-token",
|
||||
"refresh_token": "rt2",
|
||||
"expires_in": int64(3600),
|
||||
"token_type": "Bearer",
|
||||
})
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
originalURL := antigravityOAuthTokenURL
|
||||
antigravityOAuthTokenURL = srv.URL
|
||||
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
|
||||
|
||||
store := &memoryAuthStore{}
|
||||
manager := coreauth.NewManager(store, nil, nil)
|
||||
|
||||
auth := &coreauth.Auth{
|
||||
ID: "antigravity-test.json",
|
||||
FileName: "antigravity-test.json",
|
||||
Provider: "antigravity",
|
||||
Metadata: map[string]any{
|
||||
"type": "antigravity",
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "rt",
|
||||
"expires_in": int64(3600),
|
||||
"timestamp": time.Now().Add(-2 * time.Hour).UnixMilli(),
|
||||
"expired": time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||
t.Fatalf("register auth: %v", err)
|
||||
}
|
||||
|
||||
h := &Handler{authManager: manager}
|
||||
token, err := h.resolveTokenForAuth(context.Background(), auth)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveTokenForAuth: %v", err)
|
||||
}
|
||||
if token != "new-token" {
|
||||
t.Fatalf("expected refreshed token, got %q", token)
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Fatalf("expected 1 refresh call, got %d", callCount)
|
||||
}
|
||||
|
||||
updated, ok := manager.GetByID(auth.ID)
|
||||
if !ok || updated == nil {
|
||||
t.Fatalf("expected auth in manager after update")
|
||||
}
|
||||
if got := tokenValueFromMetadata(updated.Metadata); got != "new-token" {
|
||||
t.Fatalf("expected manager metadata updated, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveTokenForAuth_Antigravity_SkipsRefreshWhenTokenValid(t *testing.T) {
|
||||
var callCount int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
originalURL := antigravityOAuthTokenURL
|
||||
antigravityOAuthTokenURL = srv.URL
|
||||
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
|
||||
|
||||
auth := &coreauth.Auth{
|
||||
ID: "antigravity-valid.json",
|
||||
FileName: "antigravity-valid.json",
|
||||
Provider: "antigravity",
|
||||
Metadata: map[string]any{
|
||||
"type": "antigravity",
|
||||
"access_token": "ok-token",
|
||||
"expired": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
h := &Handler{}
|
||||
token, err := h.resolveTokenForAuth(context.Background(), auth)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveTokenForAuth: %v", err)
|
||||
}
|
||||
if token != "ok-token" {
|
||||
t.Fatalf("expected existing token, got %q", token)
|
||||
}
|
||||
if callCount != 0 {
|
||||
t.Fatalf("expected no refresh calls, got %d", callCount)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -202,6 +202,26 @@ func (h *Handler) PutLoggingToFile(c *gin.Context) {
|
||||
h.updateBoolField(c, func(v bool) { h.cfg.LoggingToFile = v })
|
||||
}
|
||||
|
||||
// LogsMaxTotalSizeMB
|
||||
func (h *Handler) GetLogsMaxTotalSizeMB(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"logs-max-total-size-mb": h.cfg.LogsMaxTotalSizeMB})
|
||||
}
|
||||
func (h *Handler) PutLogsMaxTotalSizeMB(c *gin.Context) {
|
||||
var body struct {
|
||||
Value *int `json:"value"`
|
||||
}
|
||||
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
value := *body.Value
|
||||
if value < 0 {
|
||||
value = 0
|
||||
}
|
||||
h.cfg.LogsMaxTotalSizeMB = value
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// Request log
|
||||
func (h *Handler) GetRequestLog(c *gin.Context) { c.JSON(200, gin.H{"request-log": h.cfg.RequestLog}) }
|
||||
func (h *Handler) PutRequestLog(c *gin.Context) {
|
||||
@@ -232,6 +252,52 @@ func (h *Handler) PutMaxRetryInterval(c *gin.Context) {
|
||||
h.updateIntField(c, func(v int) { h.cfg.MaxRetryInterval = v })
|
||||
}
|
||||
|
||||
// ForceModelPrefix
|
||||
func (h *Handler) GetForceModelPrefix(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"force-model-prefix": h.cfg.ForceModelPrefix})
|
||||
}
|
||||
func (h *Handler) PutForceModelPrefix(c *gin.Context) {
|
||||
h.updateBoolField(c, func(v bool) { h.cfg.ForceModelPrefix = v })
|
||||
}
|
||||
|
||||
func normalizeRoutingStrategy(strategy string) (string, bool) {
|
||||
normalized := strings.ToLower(strings.TrimSpace(strategy))
|
||||
switch normalized {
|
||||
case "", "round-robin", "roundrobin", "rr":
|
||||
return "round-robin", true
|
||||
case "fill-first", "fillfirst", "ff":
|
||||
return "fill-first", true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
// RoutingStrategy
|
||||
func (h *Handler) GetRoutingStrategy(c *gin.Context) {
|
||||
strategy, ok := normalizeRoutingStrategy(h.cfg.Routing.Strategy)
|
||||
if !ok {
|
||||
c.JSON(200, gin.H{"strategy": strings.TrimSpace(h.cfg.Routing.Strategy)})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"strategy": strategy})
|
||||
}
|
||||
func (h *Handler) PutRoutingStrategy(c *gin.Context) {
|
||||
var body struct {
|
||||
Value *string `json:"value"`
|
||||
}
|
||||
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
normalized, ok := normalizeRoutingStrategy(*body.Value)
|
||||
if !ok {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid strategy"})
|
||||
return
|
||||
}
|
||||
h.cfg.Routing.Strategy = normalized
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// Proxy URL
|
||||
func (h *Handler) GetProxyURL(c *gin.Context) { c.JSON(200, gin.H{"proxy-url": h.cfg.ProxyURL}) }
|
||||
func (h *Handler) PutProxyURL(c *gin.Context) {
|
||||
|
||||
@@ -145,71 +145,74 @@ func (h *Handler) PutGeminiKeys(c *gin.Context) {
|
||||
h.persist(c)
|
||||
}
|
||||
func (h *Handler) PatchGeminiKey(c *gin.Context) {
|
||||
type geminiKeyPatch struct {
|
||||
APIKey *string `json:"api-key"`
|
||||
Prefix *string `json:"prefix"`
|
||||
BaseURL *string `json:"base-url"`
|
||||
ProxyURL *string `json:"proxy-url"`
|
||||
Headers *map[string]string `json:"headers"`
|
||||
ExcludedModels *[]string `json:"excluded-models"`
|
||||
}
|
||||
var body struct {
|
||||
Index *int `json:"index"`
|
||||
Match *string `json:"match"`
|
||||
Value *config.GeminiKey `json:"value"`
|
||||
Index *int `json:"index"`
|
||||
Match *string `json:"match"`
|
||||
Value *geminiKeyPatch `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
value := *body.Value
|
||||
value.APIKey = strings.TrimSpace(value.APIKey)
|
||||
value.BaseURL = strings.TrimSpace(value.BaseURL)
|
||||
value.ProxyURL = strings.TrimSpace(value.ProxyURL)
|
||||
value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels)
|
||||
if value.APIKey == "" {
|
||||
// Treat empty API key as delete.
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) {
|
||||
h.cfg.GeminiKey = append(h.cfg.GeminiKey[:*body.Index], h.cfg.GeminiKey[*body.Index+1:]...)
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if body.Match != nil {
|
||||
match := strings.TrimSpace(*body.Match)
|
||||
if match != "" {
|
||||
out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey))
|
||||
removed := false
|
||||
for i := range h.cfg.GeminiKey {
|
||||
if !removed && h.cfg.GeminiKey[i].APIKey == match {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
out = append(out, h.cfg.GeminiKey[i])
|
||||
}
|
||||
if removed {
|
||||
h.cfg.GeminiKey = out
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
targetIndex := -1
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) {
|
||||
targetIndex = *body.Index
|
||||
}
|
||||
if targetIndex == -1 && body.Match != nil {
|
||||
match := strings.TrimSpace(*body.Match)
|
||||
if match != "" {
|
||||
for i := range h.cfg.GeminiKey {
|
||||
if h.cfg.GeminiKey[i].APIKey == match {
|
||||
targetIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if targetIndex == -1 {
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
return
|
||||
}
|
||||
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) {
|
||||
h.cfg.GeminiKey[*body.Index] = value
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if body.Match != nil {
|
||||
match := strings.TrimSpace(*body.Match)
|
||||
for i := range h.cfg.GeminiKey {
|
||||
if h.cfg.GeminiKey[i].APIKey == match {
|
||||
h.cfg.GeminiKey[i] = value
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
entry := h.cfg.GeminiKey[targetIndex]
|
||||
if body.Value.APIKey != nil {
|
||||
trimmed := strings.TrimSpace(*body.Value.APIKey)
|
||||
if trimmed == "" {
|
||||
h.cfg.GeminiKey = append(h.cfg.GeminiKey[:targetIndex], h.cfg.GeminiKey[targetIndex+1:]...)
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
entry.APIKey = trimmed
|
||||
}
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
if body.Value.Prefix != nil {
|
||||
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
|
||||
}
|
||||
if body.Value.BaseURL != nil {
|
||||
entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL)
|
||||
}
|
||||
if body.Value.ProxyURL != nil {
|
||||
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
|
||||
}
|
||||
if body.Value.Headers != nil {
|
||||
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
|
||||
}
|
||||
if body.Value.ExcludedModels != nil {
|
||||
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
|
||||
}
|
||||
h.cfg.GeminiKey[targetIndex] = entry
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteGeminiKey(c *gin.Context) {
|
||||
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
||||
out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey))
|
||||
@@ -268,35 +271,70 @@ func (h *Handler) PutClaudeKeys(c *gin.Context) {
|
||||
h.persist(c)
|
||||
}
|
||||
func (h *Handler) PatchClaudeKey(c *gin.Context) {
|
||||
type claudeKeyPatch struct {
|
||||
APIKey *string `json:"api-key"`
|
||||
Prefix *string `json:"prefix"`
|
||||
BaseURL *string `json:"base-url"`
|
||||
ProxyURL *string `json:"proxy-url"`
|
||||
Models *[]config.ClaudeModel `json:"models"`
|
||||
Headers *map[string]string `json:"headers"`
|
||||
ExcludedModels *[]string `json:"excluded-models"`
|
||||
}
|
||||
var body struct {
|
||||
Index *int `json:"index"`
|
||||
Match *string `json:"match"`
|
||||
Value *config.ClaudeKey `json:"value"`
|
||||
Index *int `json:"index"`
|
||||
Match *string `json:"match"`
|
||||
Value *claudeKeyPatch `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
value := *body.Value
|
||||
normalizeClaudeKey(&value)
|
||||
targetIndex := -1
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.ClaudeKey) {
|
||||
h.cfg.ClaudeKey[*body.Index] = value
|
||||
h.cfg.SanitizeClaudeKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
targetIndex = *body.Index
|
||||
}
|
||||
if body.Match != nil {
|
||||
if targetIndex == -1 && body.Match != nil {
|
||||
match := strings.TrimSpace(*body.Match)
|
||||
for i := range h.cfg.ClaudeKey {
|
||||
if h.cfg.ClaudeKey[i].APIKey == *body.Match {
|
||||
h.cfg.ClaudeKey[i] = value
|
||||
h.cfg.SanitizeClaudeKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
if h.cfg.ClaudeKey[i].APIKey == match {
|
||||
targetIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
if targetIndex == -1 {
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
return
|
||||
}
|
||||
|
||||
entry := h.cfg.ClaudeKey[targetIndex]
|
||||
if body.Value.APIKey != nil {
|
||||
entry.APIKey = strings.TrimSpace(*body.Value.APIKey)
|
||||
}
|
||||
if body.Value.Prefix != nil {
|
||||
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
|
||||
}
|
||||
if body.Value.BaseURL != nil {
|
||||
entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL)
|
||||
}
|
||||
if body.Value.ProxyURL != nil {
|
||||
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
|
||||
}
|
||||
if body.Value.Models != nil {
|
||||
entry.Models = append([]config.ClaudeModel(nil), (*body.Value.Models)...)
|
||||
}
|
||||
if body.Value.Headers != nil {
|
||||
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
|
||||
}
|
||||
if body.Value.ExcludedModels != nil {
|
||||
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
|
||||
}
|
||||
normalizeClaudeKey(&entry)
|
||||
h.cfg.ClaudeKey[targetIndex] = entry
|
||||
h.cfg.SanitizeClaudeKeys()
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteClaudeKey(c *gin.Context) {
|
||||
if val := c.Query("api-key"); val != "" {
|
||||
out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey))
|
||||
@@ -356,62 +394,73 @@ func (h *Handler) PutOpenAICompat(c *gin.Context) {
|
||||
h.persist(c)
|
||||
}
|
||||
func (h *Handler) PatchOpenAICompat(c *gin.Context) {
|
||||
type openAICompatPatch struct {
|
||||
Name *string `json:"name"`
|
||||
Prefix *string `json:"prefix"`
|
||||
BaseURL *string `json:"base-url"`
|
||||
APIKeyEntries *[]config.OpenAICompatibilityAPIKey `json:"api-key-entries"`
|
||||
Models *[]config.OpenAICompatibilityModel `json:"models"`
|
||||
Headers *map[string]string `json:"headers"`
|
||||
}
|
||||
var body struct {
|
||||
Name *string `json:"name"`
|
||||
Index *int `json:"index"`
|
||||
Value *config.OpenAICompatibility `json:"value"`
|
||||
Name *string `json:"name"`
|
||||
Index *int `json:"index"`
|
||||
Value *openAICompatPatch `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
normalizeOpenAICompatibilityEntry(body.Value)
|
||||
// If base-url becomes empty, delete the provider instead of updating
|
||||
if strings.TrimSpace(body.Value.BaseURL) == "" {
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) {
|
||||
h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:*body.Index], h.cfg.OpenAICompatibility[*body.Index+1:]...)
|
||||
targetIndex := -1
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) {
|
||||
targetIndex = *body.Index
|
||||
}
|
||||
if targetIndex == -1 && body.Name != nil {
|
||||
match := strings.TrimSpace(*body.Name)
|
||||
for i := range h.cfg.OpenAICompatibility {
|
||||
if h.cfg.OpenAICompatibility[i].Name == match {
|
||||
targetIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if targetIndex == -1 {
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
return
|
||||
}
|
||||
|
||||
entry := h.cfg.OpenAICompatibility[targetIndex]
|
||||
if body.Value.Name != nil {
|
||||
entry.Name = strings.TrimSpace(*body.Value.Name)
|
||||
}
|
||||
if body.Value.Prefix != nil {
|
||||
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
|
||||
}
|
||||
if body.Value.BaseURL != nil {
|
||||
trimmed := strings.TrimSpace(*body.Value.BaseURL)
|
||||
if trimmed == "" {
|
||||
h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:targetIndex], h.cfg.OpenAICompatibility[targetIndex+1:]...)
|
||||
h.cfg.SanitizeOpenAICompatibility()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if body.Name != nil {
|
||||
out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility))
|
||||
removed := false
|
||||
for i := range h.cfg.OpenAICompatibility {
|
||||
if !removed && h.cfg.OpenAICompatibility[i].Name == *body.Name {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
out = append(out, h.cfg.OpenAICompatibility[i])
|
||||
}
|
||||
if removed {
|
||||
h.cfg.OpenAICompatibility = out
|
||||
h.cfg.SanitizeOpenAICompatibility()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
return
|
||||
entry.BaseURL = trimmed
|
||||
}
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) {
|
||||
h.cfg.OpenAICompatibility[*body.Index] = *body.Value
|
||||
h.cfg.SanitizeOpenAICompatibility()
|
||||
h.persist(c)
|
||||
return
|
||||
if body.Value.APIKeyEntries != nil {
|
||||
entry.APIKeyEntries = append([]config.OpenAICompatibilityAPIKey(nil), (*body.Value.APIKeyEntries)...)
|
||||
}
|
||||
if body.Name != nil {
|
||||
for i := range h.cfg.OpenAICompatibility {
|
||||
if h.cfg.OpenAICompatibility[i].Name == *body.Name {
|
||||
h.cfg.OpenAICompatibility[i] = *body.Value
|
||||
h.cfg.SanitizeOpenAICompatibility()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
if body.Value.Models != nil {
|
||||
entry.Models = append([]config.OpenAICompatibilityModel(nil), (*body.Value.Models)...)
|
||||
}
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
if body.Value.Headers != nil {
|
||||
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
|
||||
}
|
||||
normalizeOpenAICompatibilityEntry(&entry)
|
||||
h.cfg.OpenAICompatibility[targetIndex] = entry
|
||||
h.cfg.SanitizeOpenAICompatibility()
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteOpenAICompat(c *gin.Context) {
|
||||
if name := c.Query("name"); name != "" {
|
||||
out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility))
|
||||
@@ -438,6 +487,137 @@ func (h *Handler) DeleteOpenAICompat(c *gin.Context) {
|
||||
c.JSON(400, gin.H{"error": "missing name or index"})
|
||||
}
|
||||
|
||||
// vertex-api-key: []VertexCompatKey
|
||||
func (h *Handler) GetVertexCompatKeys(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"vertex-api-key": h.cfg.VertexCompatAPIKey})
|
||||
}
|
||||
func (h *Handler) PutVertexCompatKeys(c *gin.Context) {
|
||||
data, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"error": "failed to read body"})
|
||||
return
|
||||
}
|
||||
var arr []config.VertexCompatKey
|
||||
if err = json.Unmarshal(data, &arr); err != nil {
|
||||
var obj struct {
|
||||
Items []config.VertexCompatKey `json:"items"`
|
||||
}
|
||||
if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
arr = obj.Items
|
||||
}
|
||||
for i := range arr {
|
||||
normalizeVertexCompatKey(&arr[i])
|
||||
}
|
||||
h.cfg.VertexCompatAPIKey = arr
|
||||
h.cfg.SanitizeVertexCompatKeys()
|
||||
h.persist(c)
|
||||
}
|
||||
func (h *Handler) PatchVertexCompatKey(c *gin.Context) {
|
||||
type vertexCompatPatch struct {
|
||||
APIKey *string `json:"api-key"`
|
||||
Prefix *string `json:"prefix"`
|
||||
BaseURL *string `json:"base-url"`
|
||||
ProxyURL *string `json:"proxy-url"`
|
||||
Headers *map[string]string `json:"headers"`
|
||||
Models *[]config.VertexCompatModel `json:"models"`
|
||||
}
|
||||
var body struct {
|
||||
Index *int `json:"index"`
|
||||
Match *string `json:"match"`
|
||||
Value *vertexCompatPatch `json:"value"`
|
||||
}
|
||||
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
targetIndex := -1
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.VertexCompatAPIKey) {
|
||||
targetIndex = *body.Index
|
||||
}
|
||||
if targetIndex == -1 && body.Match != nil {
|
||||
match := strings.TrimSpace(*body.Match)
|
||||
if match != "" {
|
||||
for i := range h.cfg.VertexCompatAPIKey {
|
||||
if h.cfg.VertexCompatAPIKey[i].APIKey == match {
|
||||
targetIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if targetIndex == -1 {
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
return
|
||||
}
|
||||
|
||||
entry := h.cfg.VertexCompatAPIKey[targetIndex]
|
||||
if body.Value.APIKey != nil {
|
||||
trimmed := strings.TrimSpace(*body.Value.APIKey)
|
||||
if trimmed == "" {
|
||||
h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:targetIndex], h.cfg.VertexCompatAPIKey[targetIndex+1:]...)
|
||||
h.cfg.SanitizeVertexCompatKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
entry.APIKey = trimmed
|
||||
}
|
||||
if body.Value.Prefix != nil {
|
||||
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
|
||||
}
|
||||
if body.Value.BaseURL != nil {
|
||||
trimmed := strings.TrimSpace(*body.Value.BaseURL)
|
||||
if trimmed == "" {
|
||||
h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:targetIndex], h.cfg.VertexCompatAPIKey[targetIndex+1:]...)
|
||||
h.cfg.SanitizeVertexCompatKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
entry.BaseURL = trimmed
|
||||
}
|
||||
if body.Value.ProxyURL != nil {
|
||||
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
|
||||
}
|
||||
if body.Value.Headers != nil {
|
||||
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
|
||||
}
|
||||
if body.Value.Models != nil {
|
||||
entry.Models = append([]config.VertexCompatModel(nil), (*body.Value.Models)...)
|
||||
}
|
||||
normalizeVertexCompatKey(&entry)
|
||||
h.cfg.VertexCompatAPIKey[targetIndex] = entry
|
||||
h.cfg.SanitizeVertexCompatKeys()
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteVertexCompatKey(c *gin.Context) {
|
||||
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
||||
out := make([]config.VertexCompatKey, 0, len(h.cfg.VertexCompatAPIKey))
|
||||
for _, v := range h.cfg.VertexCompatAPIKey {
|
||||
if v.APIKey != val {
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
h.cfg.VertexCompatAPIKey = out
|
||||
h.cfg.SanitizeVertexCompatKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if idxStr := c.Query("index"); idxStr != "" {
|
||||
var idx int
|
||||
_, errScan := fmt.Sscanf(idxStr, "%d", &idx)
|
||||
if errScan == nil && idx >= 0 && idx < len(h.cfg.VertexCompatAPIKey) {
|
||||
h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:idx], h.cfg.VertexCompatAPIKey[idx+1:]...)
|
||||
h.cfg.SanitizeVertexCompatKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.JSON(400, gin.H{"error": "missing api-key or index"})
|
||||
}
|
||||
|
||||
// oauth-excluded-models: map[string][]string
|
||||
func (h *Handler) GetOAuthExcludedModels(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"oauth-excluded-models": config.NormalizeOAuthExcludedModels(h.cfg.OAuthExcludedModels)})
|
||||
@@ -523,6 +703,103 @@ func (h *Handler) DeleteOAuthExcludedModels(c *gin.Context) {
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// oauth-model-alias: map[string][]OAuthModelAlias
|
||||
func (h *Handler) GetOAuthModelAlias(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"oauth-model-alias": sanitizedOAuthModelAlias(h.cfg.OAuthModelAlias)})
|
||||
}
|
||||
|
||||
func (h *Handler) PutOAuthModelAlias(c *gin.Context) {
|
||||
data, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"error": "failed to read body"})
|
||||
return
|
||||
}
|
||||
var entries map[string][]config.OAuthModelAlias
|
||||
if err = json.Unmarshal(data, &entries); err != nil {
|
||||
var wrapper struct {
|
||||
Items map[string][]config.OAuthModelAlias `json:"items"`
|
||||
}
|
||||
if err2 := json.Unmarshal(data, &wrapper); err2 != nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
entries = wrapper.Items
|
||||
}
|
||||
h.cfg.OAuthModelAlias = sanitizedOAuthModelAlias(entries)
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
func (h *Handler) PatchOAuthModelAlias(c *gin.Context) {
|
||||
var body struct {
|
||||
Provider *string `json:"provider"`
|
||||
Channel *string `json:"channel"`
|
||||
Aliases []config.OAuthModelAlias `json:"aliases"`
|
||||
}
|
||||
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
channelRaw := ""
|
||||
if body.Channel != nil {
|
||||
channelRaw = *body.Channel
|
||||
} else if body.Provider != nil {
|
||||
channelRaw = *body.Provider
|
||||
}
|
||||
channel := strings.ToLower(strings.TrimSpace(channelRaw))
|
||||
if channel == "" {
|
||||
c.JSON(400, gin.H{"error": "invalid channel"})
|
||||
return
|
||||
}
|
||||
|
||||
normalizedMap := sanitizedOAuthModelAlias(map[string][]config.OAuthModelAlias{channel: body.Aliases})
|
||||
normalized := normalizedMap[channel]
|
||||
if len(normalized) == 0 {
|
||||
if h.cfg.OAuthModelAlias == nil {
|
||||
c.JSON(404, gin.H{"error": "channel not found"})
|
||||
return
|
||||
}
|
||||
if _, ok := h.cfg.OAuthModelAlias[channel]; !ok {
|
||||
c.JSON(404, gin.H{"error": "channel not found"})
|
||||
return
|
||||
}
|
||||
delete(h.cfg.OAuthModelAlias, channel)
|
||||
if len(h.cfg.OAuthModelAlias) == 0 {
|
||||
h.cfg.OAuthModelAlias = nil
|
||||
}
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if h.cfg.OAuthModelAlias == nil {
|
||||
h.cfg.OAuthModelAlias = make(map[string][]config.OAuthModelAlias)
|
||||
}
|
||||
h.cfg.OAuthModelAlias[channel] = normalized
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteOAuthModelAlias(c *gin.Context) {
|
||||
channel := strings.ToLower(strings.TrimSpace(c.Query("channel")))
|
||||
if channel == "" {
|
||||
channel = strings.ToLower(strings.TrimSpace(c.Query("provider")))
|
||||
}
|
||||
if channel == "" {
|
||||
c.JSON(400, gin.H{"error": "missing channel"})
|
||||
return
|
||||
}
|
||||
if h.cfg.OAuthModelAlias == nil {
|
||||
c.JSON(404, gin.H{"error": "channel not found"})
|
||||
return
|
||||
}
|
||||
if _, ok := h.cfg.OAuthModelAlias[channel]; !ok {
|
||||
c.JSON(404, gin.H{"error": "channel not found"})
|
||||
return
|
||||
}
|
||||
delete(h.cfg.OAuthModelAlias, channel)
|
||||
if len(h.cfg.OAuthModelAlias) == 0 {
|
||||
h.cfg.OAuthModelAlias = nil
|
||||
}
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// codex-api-key: []CodexKey
|
||||
func (h *Handler) GetCodexKeys(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"codex-api-key": h.cfg.CodexKey})
|
||||
@@ -548,11 +825,7 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
|
||||
filtered := make([]config.CodexKey, 0, len(arr))
|
||||
for i := range arr {
|
||||
entry := arr[i]
|
||||
entry.APIKey = strings.TrimSpace(entry.APIKey)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = config.NormalizeHeaders(entry.Headers)
|
||||
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
|
||||
normalizeCodexKey(&entry)
|
||||
if entry.BaseURL == "" {
|
||||
continue
|
||||
}
|
||||
@@ -563,66 +836,77 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
|
||||
h.persist(c)
|
||||
}
|
||||
func (h *Handler) PatchCodexKey(c *gin.Context) {
|
||||
type codexKeyPatch struct {
|
||||
APIKey *string `json:"api-key"`
|
||||
Prefix *string `json:"prefix"`
|
||||
BaseURL *string `json:"base-url"`
|
||||
ProxyURL *string `json:"proxy-url"`
|
||||
Models *[]config.CodexModel `json:"models"`
|
||||
Headers *map[string]string `json:"headers"`
|
||||
ExcludedModels *[]string `json:"excluded-models"`
|
||||
}
|
||||
var body struct {
|
||||
Index *int `json:"index"`
|
||||
Match *string `json:"match"`
|
||||
Value *config.CodexKey `json:"value"`
|
||||
Index *int `json:"index"`
|
||||
Match *string `json:"match"`
|
||||
Value *codexKeyPatch `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
value := *body.Value
|
||||
value.APIKey = strings.TrimSpace(value.APIKey)
|
||||
value.BaseURL = strings.TrimSpace(value.BaseURL)
|
||||
value.ProxyURL = strings.TrimSpace(value.ProxyURL)
|
||||
value.Headers = config.NormalizeHeaders(value.Headers)
|
||||
value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels)
|
||||
// If base-url becomes empty, delete instead of update
|
||||
if value.BaseURL == "" {
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) {
|
||||
h.cfg.CodexKey = append(h.cfg.CodexKey[:*body.Index], h.cfg.CodexKey[*body.Index+1:]...)
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if body.Match != nil {
|
||||
out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))
|
||||
removed := false
|
||||
for i := range h.cfg.CodexKey {
|
||||
if !removed && h.cfg.CodexKey[i].APIKey == *body.Match {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
out = append(out, h.cfg.CodexKey[i])
|
||||
}
|
||||
if removed {
|
||||
h.cfg.CodexKey = out
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) {
|
||||
h.cfg.CodexKey[*body.Index] = value
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if body.Match != nil {
|
||||
for i := range h.cfg.CodexKey {
|
||||
if h.cfg.CodexKey[i].APIKey == *body.Match {
|
||||
h.cfg.CodexKey[i] = value
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
targetIndex := -1
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) {
|
||||
targetIndex = *body.Index
|
||||
}
|
||||
if targetIndex == -1 && body.Match != nil {
|
||||
match := strings.TrimSpace(*body.Match)
|
||||
for i := range h.cfg.CodexKey {
|
||||
if h.cfg.CodexKey[i].APIKey == match {
|
||||
targetIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
if targetIndex == -1 {
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
return
|
||||
}
|
||||
|
||||
entry := h.cfg.CodexKey[targetIndex]
|
||||
if body.Value.APIKey != nil {
|
||||
entry.APIKey = strings.TrimSpace(*body.Value.APIKey)
|
||||
}
|
||||
if body.Value.Prefix != nil {
|
||||
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
|
||||
}
|
||||
if body.Value.BaseURL != nil {
|
||||
trimmed := strings.TrimSpace(*body.Value.BaseURL)
|
||||
if trimmed == "" {
|
||||
h.cfg.CodexKey = append(h.cfg.CodexKey[:targetIndex], h.cfg.CodexKey[targetIndex+1:]...)
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
entry.BaseURL = trimmed
|
||||
}
|
||||
if body.Value.ProxyURL != nil {
|
||||
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
|
||||
}
|
||||
if body.Value.Models != nil {
|
||||
entry.Models = append([]config.CodexModel(nil), (*body.Value.Models)...)
|
||||
}
|
||||
if body.Value.Headers != nil {
|
||||
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
|
||||
}
|
||||
if body.Value.ExcludedModels != nil {
|
||||
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
|
||||
}
|
||||
normalizeCodexKey(&entry)
|
||||
h.cfg.CodexKey[targetIndex] = entry
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteCodexKey(c *gin.Context) {
|
||||
if val := c.Query("api-key"); val != "" {
|
||||
out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))
|
||||
@@ -707,6 +991,79 @@ func normalizeClaudeKey(entry *config.ClaudeKey) {
|
||||
entry.Models = normalized
|
||||
}
|
||||
|
||||
func normalizeCodexKey(entry *config.CodexKey) {
|
||||
if entry == nil {
|
||||
return
|
||||
}
|
||||
entry.APIKey = strings.TrimSpace(entry.APIKey)
|
||||
entry.Prefix = strings.TrimSpace(entry.Prefix)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = config.NormalizeHeaders(entry.Headers)
|
||||
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
|
||||
if len(entry.Models) == 0 {
|
||||
return
|
||||
}
|
||||
normalized := make([]config.CodexModel, 0, len(entry.Models))
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
model.Name = strings.TrimSpace(model.Name)
|
||||
model.Alias = strings.TrimSpace(model.Alias)
|
||||
if model.Name == "" && model.Alias == "" {
|
||||
continue
|
||||
}
|
||||
normalized = append(normalized, model)
|
||||
}
|
||||
entry.Models = normalized
|
||||
}
|
||||
|
||||
func normalizeVertexCompatKey(entry *config.VertexCompatKey) {
|
||||
if entry == nil {
|
||||
return
|
||||
}
|
||||
entry.APIKey = strings.TrimSpace(entry.APIKey)
|
||||
entry.Prefix = strings.TrimSpace(entry.Prefix)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = config.NormalizeHeaders(entry.Headers)
|
||||
if len(entry.Models) == 0 {
|
||||
return
|
||||
}
|
||||
normalized := make([]config.VertexCompatModel, 0, len(entry.Models))
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
model.Name = strings.TrimSpace(model.Name)
|
||||
model.Alias = strings.TrimSpace(model.Alias)
|
||||
if model.Name == "" || model.Alias == "" {
|
||||
continue
|
||||
}
|
||||
normalized = append(normalized, model)
|
||||
}
|
||||
entry.Models = normalized
|
||||
}
|
||||
|
||||
func sanitizedOAuthModelAlias(entries map[string][]config.OAuthModelAlias) map[string][]config.OAuthModelAlias {
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
copied := make(map[string][]config.OAuthModelAlias, len(entries))
|
||||
for channel, aliases := range entries {
|
||||
if len(aliases) == 0 {
|
||||
continue
|
||||
}
|
||||
copied[channel] = append([]config.OAuthModelAlias(nil), aliases...)
|
||||
}
|
||||
if len(copied) == 0 {
|
||||
return nil
|
||||
}
|
||||
cfg := config.Config{OAuthModelAlias: copied}
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
if len(cfg.OAuthModelAlias) == 0 {
|
||||
return nil
|
||||
}
|
||||
return cfg.OAuthModelAlias
|
||||
}
|
||||
|
||||
// GetAmpCode returns the complete ampcode configuration.
|
||||
func (h *Handler) GetAmpCode(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
@@ -858,3 +1215,151 @@ func (h *Handler) GetAmpForceModelMappings(c *gin.Context) {
|
||||
func (h *Handler) PutAmpForceModelMappings(c *gin.Context) {
|
||||
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v })
|
||||
}
|
||||
|
||||
// GetAmpUpstreamAPIKeys returns the ampcode upstream API keys mapping.
|
||||
func (h *Handler) GetAmpUpstreamAPIKeys(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
c.JSON(200, gin.H{"upstream-api-keys": []config.AmpUpstreamAPIKeyEntry{}})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"upstream-api-keys": h.cfg.AmpCode.UpstreamAPIKeys})
|
||||
}
|
||||
|
||||
// PutAmpUpstreamAPIKeys replaces all ampcode upstream API keys mappings.
|
||||
func (h *Handler) PutAmpUpstreamAPIKeys(c *gin.Context) {
|
||||
var body struct {
|
||||
Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
// Normalize entries: trim whitespace, filter empty
|
||||
normalized := normalizeAmpUpstreamAPIKeyEntries(body.Value)
|
||||
h.cfg.AmpCode.UpstreamAPIKeys = normalized
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// PatchAmpUpstreamAPIKeys adds or updates upstream API keys entries.
|
||||
// Matching is done by upstream-api-key value.
|
||||
func (h *Handler) PatchAmpUpstreamAPIKeys(c *gin.Context) {
|
||||
var body struct {
|
||||
Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
|
||||
existing := make(map[string]int)
|
||||
for i, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
|
||||
existing[strings.TrimSpace(entry.UpstreamAPIKey)] = i
|
||||
}
|
||||
|
||||
for _, newEntry := range body.Value {
|
||||
upstreamKey := strings.TrimSpace(newEntry.UpstreamAPIKey)
|
||||
if upstreamKey == "" {
|
||||
continue
|
||||
}
|
||||
normalizedEntry := config.AmpUpstreamAPIKeyEntry{
|
||||
UpstreamAPIKey: upstreamKey,
|
||||
APIKeys: normalizeAPIKeysList(newEntry.APIKeys),
|
||||
}
|
||||
if idx, ok := existing[upstreamKey]; ok {
|
||||
h.cfg.AmpCode.UpstreamAPIKeys[idx] = normalizedEntry
|
||||
} else {
|
||||
h.cfg.AmpCode.UpstreamAPIKeys = append(h.cfg.AmpCode.UpstreamAPIKeys, normalizedEntry)
|
||||
existing[upstreamKey] = len(h.cfg.AmpCode.UpstreamAPIKeys) - 1
|
||||
}
|
||||
}
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// DeleteAmpUpstreamAPIKeys removes specified upstream API keys entries.
|
||||
// Body must be JSON: {"value": ["<upstream-api-key>", ...]}.
|
||||
// If "value" is an empty array, clears all entries.
|
||||
// If JSON is invalid or "value" is missing/null, returns 400 and does not persist any change.
|
||||
func (h *Handler) DeleteAmpUpstreamAPIKeys(c *gin.Context) {
|
||||
var body struct {
|
||||
Value []string `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
|
||||
if body.Value == nil {
|
||||
c.JSON(400, gin.H{"error": "missing value"})
|
||||
return
|
||||
}
|
||||
|
||||
// Empty array means clear all
|
||||
if len(body.Value) == 0 {
|
||||
h.cfg.AmpCode.UpstreamAPIKeys = nil
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
|
||||
toRemove := make(map[string]bool)
|
||||
for _, key := range body.Value {
|
||||
trimmed := strings.TrimSpace(key)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
toRemove[trimmed] = true
|
||||
}
|
||||
if len(toRemove) == 0 {
|
||||
c.JSON(400, gin.H{"error": "empty value"})
|
||||
return
|
||||
}
|
||||
|
||||
newEntries := make([]config.AmpUpstreamAPIKeyEntry, 0, len(h.cfg.AmpCode.UpstreamAPIKeys))
|
||||
for _, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
|
||||
if !toRemove[strings.TrimSpace(entry.UpstreamAPIKey)] {
|
||||
newEntries = append(newEntries, entry)
|
||||
}
|
||||
}
|
||||
h.cfg.AmpCode.UpstreamAPIKeys = newEntries
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// normalizeAmpUpstreamAPIKeyEntries normalizes a list of upstream API key entries.
|
||||
func normalizeAmpUpstreamAPIKeyEntries(entries []config.AmpUpstreamAPIKeyEntry) []config.AmpUpstreamAPIKeyEntry {
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]config.AmpUpstreamAPIKeyEntry, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
|
||||
if upstreamKey == "" {
|
||||
continue
|
||||
}
|
||||
apiKeys := normalizeAPIKeysList(entry.APIKeys)
|
||||
out = append(out, config.AmpUpstreamAPIKeyEntry{
|
||||
UpstreamAPIKey: upstreamKey,
|
||||
APIKeys: apiKeys,
|
||||
})
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// normalizeAPIKeysList trims and filters empty strings from a list of API keys.
|
||||
func normalizeAPIKeysList(keys []string) []string {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(keys))
|
||||
for _, k := range keys {
|
||||
trimmed := strings.TrimSpace(k)
|
||||
if trimmed != "" {
|
||||
out = append(out, trimmed)
|
||||
}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -24,8 +24,15 @@ import (
|
||||
type attemptInfo struct {
|
||||
count int
|
||||
blockedUntil time.Time
|
||||
lastActivity time.Time // track last activity for cleanup
|
||||
}
|
||||
|
||||
// attemptCleanupInterval controls how often stale IP entries are purged
|
||||
const attemptCleanupInterval = 1 * time.Hour
|
||||
|
||||
// attemptMaxIdleTime controls how long an IP can be idle before cleanup
|
||||
const attemptMaxIdleTime = 2 * time.Hour
|
||||
|
||||
// Handler aggregates config reference, persistence path and helpers.
|
||||
type Handler struct {
|
||||
cfg *config.Config
|
||||
@@ -47,7 +54,7 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
|
||||
envSecret, _ := os.LookupEnv("MANAGEMENT_PASSWORD")
|
||||
envSecret = strings.TrimSpace(envSecret)
|
||||
|
||||
return &Handler{
|
||||
h := &Handler{
|
||||
cfg: cfg,
|
||||
configFilePath: configFilePath,
|
||||
failedAttempts: make(map[string]*attemptInfo),
|
||||
@@ -57,6 +64,43 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
|
||||
allowRemoteOverride: envSecret != "",
|
||||
envSecret: envSecret,
|
||||
}
|
||||
h.startAttemptCleanup()
|
||||
return h
|
||||
}
|
||||
|
||||
// startAttemptCleanup launches a background goroutine that periodically
|
||||
// removes stale IP entries from failedAttempts to prevent memory leaks.
|
||||
func (h *Handler) startAttemptCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(attemptCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
h.purgeStaleAttempts()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// purgeStaleAttempts removes IP entries that have been idle beyond attemptMaxIdleTime
|
||||
// and whose ban (if any) has expired.
|
||||
func (h *Handler) purgeStaleAttempts() {
|
||||
now := time.Now()
|
||||
h.attemptsMu.Lock()
|
||||
defer h.attemptsMu.Unlock()
|
||||
for ip, ai := range h.failedAttempts {
|
||||
// Skip if still banned
|
||||
if !ai.blockedUntil.IsZero() && now.Before(ai.blockedUntil) {
|
||||
continue
|
||||
}
|
||||
// Remove if idle too long
|
||||
if now.Sub(ai.lastActivity) > attemptMaxIdleTime {
|
||||
delete(h.failedAttempts, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NewHandler creates a new management handler instance.
|
||||
func NewHandlerWithoutConfigFilePath(cfg *config.Config, manager *coreauth.Manager) *Handler {
|
||||
return NewHandler(cfg, "", manager)
|
||||
}
|
||||
|
||||
// SetConfig updates the in-memory config reference when the server hot-reloads.
|
||||
@@ -144,6 +188,7 @@ func (h *Handler) Middleware() gin.HandlerFunc {
|
||||
h.failedAttempts[clientIP] = aip
|
||||
}
|
||||
aip.count++
|
||||
aip.lastActivity = time.Now()
|
||||
if aip.count >= maxFailures {
|
||||
aip.blockedUntil = time.Now().Add(banDuration)
|
||||
aip.count = 0
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -209,6 +209,94 @@ func (h *Handler) GetRequestErrorLogs(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"files": files})
|
||||
}
|
||||
|
||||
// GetRequestLogByID finds and downloads a request log file by its request ID.
|
||||
// The ID is matched against the suffix of log file names (format: *-{requestID}.log).
|
||||
func (h *Handler) GetRequestLogByID(c *gin.Context) {
|
||||
if h == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
|
||||
return
|
||||
}
|
||||
if h.cfg == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
dir := h.logDirectory()
|
||||
if strings.TrimSpace(dir) == "" {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
requestID := strings.TrimSpace(c.Param("id"))
|
||||
if requestID == "" {
|
||||
requestID = strings.TrimSpace(c.Query("id"))
|
||||
}
|
||||
if requestID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing request ID"})
|
||||
return
|
||||
}
|
||||
if strings.ContainsAny(requestID, "/\\") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "log directory not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list log directory: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
suffix := "-" + requestID + ".log"
|
||||
var matchedFile string
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if strings.HasSuffix(name, suffix) {
|
||||
matchedFile = name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if matchedFile == "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found for the given request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
dirAbs, errAbs := filepath.Abs(dir)
|
||||
if errAbs != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to resolve log directory: %v", errAbs)})
|
||||
return
|
||||
}
|
||||
fullPath := filepath.Clean(filepath.Join(dirAbs, matchedFile))
|
||||
prefix := dirAbs + string(os.PathSeparator)
|
||||
if !strings.HasPrefix(fullPath, prefix) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file path"})
|
||||
return
|
||||
}
|
||||
|
||||
info, errStat := os.Stat(fullPath)
|
||||
if errStat != nil {
|
||||
if os.IsNotExist(errStat) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file: %v", errStat)})
|
||||
return
|
||||
}
|
||||
if info.IsDir() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file"})
|
||||
return
|
||||
}
|
||||
|
||||
c.FileAttachment(fullPath, matchedFile)
|
||||
}
|
||||
|
||||
// DownloadRequestErrorLog downloads a specific error request log file by name.
|
||||
func (h *Handler) DownloadRequestErrorLog(c *gin.Context) {
|
||||
if h == nil {
|
||||
@@ -272,16 +360,7 @@ func (h *Handler) logDirectory() string {
|
||||
if h.logDir != "" {
|
||||
return h.logDir
|
||||
}
|
||||
if base := util.WritablePath(); base != "" {
|
||||
return filepath.Join(base, "logs")
|
||||
}
|
||||
if h.configFilePath != "" {
|
||||
dir := filepath.Dir(h.configFilePath)
|
||||
if dir != "" && dir != "." {
|
||||
return filepath.Join(dir, "logs")
|
||||
}
|
||||
}
|
||||
return "logs"
|
||||
return logging.ResolveLogDirectory(h.cfg)
|
||||
}
|
||||
|
||||
func (h *Handler) collectLogFiles(dir string) ([]string, error) {
|
||||
|
||||
100
internal/api/handlers/management/oauth_callback.go
Normal file
100
internal/api/handlers/management/oauth_callback.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type oauthCallbackRequest struct {
|
||||
Provider string `json:"provider"`
|
||||
RedirectURL string `json:"redirect_url"`
|
||||
Code string `json:"code"`
|
||||
State string `json:"state"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func (h *Handler) PostOAuthCallback(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "handler not initialized"})
|
||||
return
|
||||
}
|
||||
|
||||
var req oauthCallbackRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid body"})
|
||||
return
|
||||
}
|
||||
|
||||
canonicalProvider, err := NormalizeOAuthProvider(req.Provider)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "unsupported provider"})
|
||||
return
|
||||
}
|
||||
|
||||
state := strings.TrimSpace(req.State)
|
||||
code := strings.TrimSpace(req.Code)
|
||||
errMsg := strings.TrimSpace(req.Error)
|
||||
|
||||
if rawRedirect := strings.TrimSpace(req.RedirectURL); rawRedirect != "" {
|
||||
u, errParse := url.Parse(rawRedirect)
|
||||
if errParse != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid redirect_url"})
|
||||
return
|
||||
}
|
||||
q := u.Query()
|
||||
if state == "" {
|
||||
state = strings.TrimSpace(q.Get("state"))
|
||||
}
|
||||
if code == "" {
|
||||
code = strings.TrimSpace(q.Get("code"))
|
||||
}
|
||||
if errMsg == "" {
|
||||
errMsg = strings.TrimSpace(q.Get("error"))
|
||||
if errMsg == "" {
|
||||
errMsg = strings.TrimSpace(q.Get("error_description"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if state == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "state is required"})
|
||||
return
|
||||
}
|
||||
if err := ValidateOAuthState(state); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"})
|
||||
return
|
||||
}
|
||||
if code == "" && errMsg == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "code or error is required"})
|
||||
return
|
||||
}
|
||||
|
||||
sessionProvider, sessionStatus, ok := GetOAuthSession(state)
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, gin.H{"status": "error", "error": "unknown or expired state"})
|
||||
return
|
||||
}
|
||||
if sessionStatus != "" {
|
||||
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"})
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(sessionProvider, canonicalProvider) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "provider does not match state"})
|
||||
return
|
||||
}
|
||||
|
||||
if _, errWrite := WriteOAuthCallbackFileForPendingSession(h.cfg.AuthDir, canonicalProvider, state, code, errMsg); errWrite != nil {
|
||||
if errors.Is(errWrite, errOAuthSessionNotPending) {
|
||||
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to persist oauth callback"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
}
|
||||
292
internal/api/handlers/management/oauth_sessions.go
Normal file
292
internal/api/handlers/management/oauth_sessions.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
oauthSessionTTL = 10 * time.Minute
|
||||
maxOAuthStateLength = 128
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidOAuthState = errors.New("invalid oauth state")
|
||||
errUnsupportedOAuthFlow = errors.New("unsupported oauth provider")
|
||||
errOAuthSessionNotPending = errors.New("oauth session is not pending")
|
||||
)
|
||||
|
||||
type oauthSession struct {
|
||||
Provider string
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type oauthSessionStore struct {
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
sessions map[string]oauthSession
|
||||
}
|
||||
|
||||
func newOAuthSessionStore(ttl time.Duration) *oauthSessionStore {
|
||||
if ttl <= 0 {
|
||||
ttl = oauthSessionTTL
|
||||
}
|
||||
return &oauthSessionStore{
|
||||
ttl: ttl,
|
||||
sessions: make(map[string]oauthSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) purgeExpiredLocked(now time.Time) {
|
||||
for state, session := range s.sessions {
|
||||
if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) {
|
||||
delete(s.sessions, state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) Register(state, provider string) {
|
||||
state = strings.TrimSpace(state)
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
if state == "" || provider == "" {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
s.sessions[state] = oauthSession{
|
||||
Provider: provider,
|
||||
Status: "",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(s.ttl),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) SetError(state, message string) {
|
||||
state = strings.TrimSpace(state)
|
||||
message = strings.TrimSpace(message)
|
||||
if state == "" {
|
||||
return
|
||||
}
|
||||
if message == "" {
|
||||
message = "Authentication failed"
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
session, ok := s.sessions[state]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
session.Status = message
|
||||
session.ExpiresAt = now.Add(s.ttl)
|
||||
s.sessions[state] = session
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) Complete(state string) {
|
||||
state = strings.TrimSpace(state)
|
||||
if state == "" {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
delete(s.sessions, state)
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) CompleteProvider(provider string) int {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
if provider == "" {
|
||||
return 0
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
removed := 0
|
||||
for state, session := range s.sessions {
|
||||
if strings.EqualFold(session.Provider, provider) {
|
||||
delete(s.sessions, state)
|
||||
removed++
|
||||
}
|
||||
}
|
||||
return removed
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) Get(state string) (oauthSession, bool) {
|
||||
state = strings.TrimSpace(state)
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
session, ok := s.sessions[state]
|
||||
return session, ok
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) IsPending(state, provider string) bool {
|
||||
state = strings.TrimSpace(state)
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
session, ok := s.sessions[state]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if session.Status != "" {
|
||||
if !strings.EqualFold(session.Provider, "kiro") {
|
||||
return false
|
||||
}
|
||||
if !strings.HasPrefix(session.Status, "device_code|") && !strings.HasPrefix(session.Status, "auth_url|") {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if provider == "" {
|
||||
return true
|
||||
}
|
||||
return strings.EqualFold(session.Provider, provider)
|
||||
}
|
||||
|
||||
var oauthSessions = newOAuthSessionStore(oauthSessionTTL)
|
||||
|
||||
func RegisterOAuthSession(state, provider string) { oauthSessions.Register(state, provider) }
|
||||
|
||||
func SetOAuthSessionError(state, message string) { oauthSessions.SetError(state, message) }
|
||||
|
||||
func CompleteOAuthSession(state string) { oauthSessions.Complete(state) }
|
||||
|
||||
func CompleteOAuthSessionsByProvider(provider string) int {
|
||||
return oauthSessions.CompleteProvider(provider)
|
||||
}
|
||||
|
||||
func GetOAuthSession(state string) (provider string, status string, ok bool) {
|
||||
session, ok := oauthSessions.Get(state)
|
||||
if !ok {
|
||||
return "", "", false
|
||||
}
|
||||
return session.Provider, session.Status, true
|
||||
}
|
||||
|
||||
func IsOAuthSessionPending(state, provider string) bool {
|
||||
return oauthSessions.IsPending(state, provider)
|
||||
}
|
||||
|
||||
func ValidateOAuthState(state string) error {
|
||||
trimmed := strings.TrimSpace(state)
|
||||
if trimmed == "" {
|
||||
return fmt.Errorf("%w: empty", errInvalidOAuthState)
|
||||
}
|
||||
if len(trimmed) > maxOAuthStateLength {
|
||||
return fmt.Errorf("%w: too long", errInvalidOAuthState)
|
||||
}
|
||||
if strings.Contains(trimmed, "/") || strings.Contains(trimmed, "\\") {
|
||||
return fmt.Errorf("%w: contains path separator", errInvalidOAuthState)
|
||||
}
|
||||
if strings.Contains(trimmed, "..") {
|
||||
return fmt.Errorf("%w: contains '..'", errInvalidOAuthState)
|
||||
}
|
||||
for _, r := range trimmed {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
case r >= 'A' && r <= 'Z':
|
||||
case r >= '0' && r <= '9':
|
||||
case r == '-' || r == '_' || r == '.':
|
||||
default:
|
||||
return fmt.Errorf("%w: invalid character", errInvalidOAuthState)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NormalizeOAuthProvider(provider string) (string, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||
case "anthropic", "claude":
|
||||
return "anthropic", nil
|
||||
case "codex", "openai":
|
||||
return "codex", nil
|
||||
case "gemini", "google":
|
||||
return "gemini", nil
|
||||
case "iflow", "i-flow":
|
||||
return "iflow", nil
|
||||
case "antigravity", "anti-gravity":
|
||||
return "antigravity", nil
|
||||
case "qwen":
|
||||
return "qwen", nil
|
||||
case "kiro":
|
||||
return "kiro", nil
|
||||
case "github":
|
||||
return "github", nil
|
||||
default:
|
||||
return "", errUnsupportedOAuthFlow
|
||||
}
|
||||
}
|
||||
|
||||
type oauthCallbackFilePayload struct {
|
||||
Code string `json:"code"`
|
||||
State string `json:"state"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func WriteOAuthCallbackFile(authDir, provider, state, code, errorMessage string) (string, error) {
|
||||
if strings.TrimSpace(authDir) == "" {
|
||||
return "", fmt.Errorf("auth dir is empty")
|
||||
}
|
||||
canonicalProvider, err := NormalizeOAuthProvider(provider)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := ValidateOAuthState(state); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf(".oauth-%s-%s.oauth", canonicalProvider, state)
|
||||
filePath := filepath.Join(authDir, fileName)
|
||||
payload := oauthCallbackFilePayload{
|
||||
Code: strings.TrimSpace(code),
|
||||
State: strings.TrimSpace(state),
|
||||
Error: strings.TrimSpace(errorMessage),
|
||||
}
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal oauth callback payload: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(filePath, data, 0o600); err != nil {
|
||||
return "", fmt.Errorf("write oauth callback file: %w", err)
|
||||
}
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
func WriteOAuthCallbackFileForPendingSession(authDir, provider, state, code, errorMessage string) (string, error) {
|
||||
canonicalProvider, err := NormalizeOAuthProvider(provider)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !IsOAuthSessionPending(state, canonicalProvider) {
|
||||
return "", errOAuthSessionNotPending
|
||||
}
|
||||
return WriteOAuthCallbackFile(authDir, canonicalProvider, state, code, errorMessage)
|
||||
}
|
||||
@@ -1,12 +1,25 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
)
|
||||
|
||||
type usageExportPayload struct {
|
||||
Version int `json:"version"`
|
||||
ExportedAt time.Time `json:"exported_at"`
|
||||
Usage usage.StatisticsSnapshot `json:"usage"`
|
||||
}
|
||||
|
||||
type usageImportPayload struct {
|
||||
Version int `json:"version"`
|
||||
Usage usage.StatisticsSnapshot `json:"usage"`
|
||||
}
|
||||
|
||||
// GetUsageStatistics returns the in-memory request statistics snapshot.
|
||||
func (h *Handler) GetUsageStatistics(c *gin.Context) {
|
||||
var snapshot usage.StatisticsSnapshot
|
||||
@@ -18,3 +31,49 @@ func (h *Handler) GetUsageStatistics(c *gin.Context) {
|
||||
"failed_requests": snapshot.FailureCount,
|
||||
})
|
||||
}
|
||||
|
||||
// ExportUsageStatistics returns a complete usage snapshot for backup/migration.
|
||||
func (h *Handler) ExportUsageStatistics(c *gin.Context) {
|
||||
var snapshot usage.StatisticsSnapshot
|
||||
if h != nil && h.usageStats != nil {
|
||||
snapshot = h.usageStats.Snapshot()
|
||||
}
|
||||
c.JSON(http.StatusOK, usageExportPayload{
|
||||
Version: 1,
|
||||
ExportedAt: time.Now().UTC(),
|
||||
Usage: snapshot,
|
||||
})
|
||||
}
|
||||
|
||||
// ImportUsageStatistics merges a previously exported usage snapshot into memory.
|
||||
func (h *Handler) ImportUsageStatistics(c *gin.Context) {
|
||||
if h == nil || h.usageStats == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "usage statistics unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
data, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
|
||||
return
|
||||
}
|
||||
|
||||
var payload usageImportPayload
|
||||
if err := json.Unmarshal(data, &payload); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json"})
|
||||
return
|
||||
}
|
||||
if payload.Version != 0 && payload.Version != 1 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported version"})
|
||||
return
|
||||
}
|
||||
|
||||
result := h.usageStats.MergeSnapshot(payload.Usage)
|
||||
snapshot := h.usageStats.Snapshot()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"added": result.Added,
|
||||
"skipped": result.Skipped,
|
||||
"total_requests": snapshot.TotalRequests,
|
||||
"failed_requests": snapshot.FailureCount,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -98,10 +98,11 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
||||
}
|
||||
|
||||
return &RequestInfo{
|
||||
URL: url,
|
||||
Method: method,
|
||||
Headers: headers,
|
||||
Body: body,
|
||||
URL: url,
|
||||
Method: method,
|
||||
Headers: headers,
|
||||
Body: body,
|
||||
RequestID: logging.GetGinRequestID(c),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -15,10 +15,11 @@ import (
|
||||
|
||||
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
||||
type RequestInfo struct {
|
||||
URL string // URL is the request URL.
|
||||
Method string // Method is the HTTP method (e.g., GET, POST).
|
||||
Headers map[string][]string // Headers contains the request headers.
|
||||
Body []byte // Body is the raw request body.
|
||||
URL string // URL is the request URL.
|
||||
Method string // Method is the HTTP method (e.g., GET, POST).
|
||||
Headers map[string][]string // Headers contains the request headers.
|
||||
Body []byte // Body is the raw request body.
|
||||
RequestID string // RequestID is the unique identifier for the request.
|
||||
}
|
||||
|
||||
// ResponseWriterWrapper wraps the standard gin.ResponseWriter to intercept and log response data.
|
||||
@@ -71,22 +72,64 @@ func (w *ResponseWriterWrapper) Write(data []byte) (int, error) {
|
||||
n, err := w.ResponseWriter.Write(data)
|
||||
|
||||
// THEN: Handle logging based on response type
|
||||
if w.isStreaming {
|
||||
if w.isStreaming && w.chunkChannel != nil {
|
||||
// For streaming responses: Send to async logging channel (non-blocking)
|
||||
if w.chunkChannel != nil {
|
||||
select {
|
||||
case w.chunkChannel <- append([]byte(nil), data...): // Non-blocking send with copy
|
||||
default: // Channel full, skip logging to avoid blocking
|
||||
}
|
||||
select {
|
||||
case w.chunkChannel <- append([]byte(nil), data...): // Non-blocking send with copy
|
||||
default: // Channel full, skip logging to avoid blocking
|
||||
}
|
||||
} else {
|
||||
// For non-streaming responses: Buffer complete response
|
||||
return n, err
|
||||
}
|
||||
|
||||
if w.shouldBufferResponseBody() {
|
||||
w.body.Write(data)
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) shouldBufferResponseBody() bool {
|
||||
if w.logger != nil && w.logger.IsEnabled() {
|
||||
return true
|
||||
}
|
||||
if !w.logOnErrorOnly {
|
||||
return false
|
||||
}
|
||||
status := w.statusCode
|
||||
if status == 0 {
|
||||
if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok && statusWriter != nil {
|
||||
status = statusWriter.Status()
|
||||
} else {
|
||||
status = http.StatusOK
|
||||
}
|
||||
}
|
||||
return status >= http.StatusBadRequest
|
||||
}
|
||||
|
||||
// WriteString wraps the underlying ResponseWriter's WriteString method to capture response data.
|
||||
// Some handlers (and fmt/io helpers) write via io.StringWriter; without this override, those writes
|
||||
// bypass Write() and would be missing from request logs.
|
||||
func (w *ResponseWriterWrapper) WriteString(data string) (int, error) {
|
||||
w.ensureHeadersCaptured()
|
||||
|
||||
// CRITICAL: Write to client first (zero latency)
|
||||
n, err := w.ResponseWriter.WriteString(data)
|
||||
|
||||
// THEN: Capture for logging
|
||||
if w.isStreaming && w.chunkChannel != nil {
|
||||
select {
|
||||
case w.chunkChannel <- []byte(data):
|
||||
default:
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
if w.shouldBufferResponseBody() {
|
||||
w.body.WriteString(data)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// WriteHeader wraps the underlying ResponseWriter's WriteHeader method.
|
||||
// It captures the status code, detects if the response is streaming based on the Content-Type header,
|
||||
// and initializes the appropriate logging mechanism (standard or streaming).
|
||||
@@ -107,6 +150,7 @@ func (w *ResponseWriterWrapper) WriteHeader(statusCode int) {
|
||||
w.requestInfo.Method,
|
||||
w.requestInfo.Headers,
|
||||
w.requestInfo.Body,
|
||||
w.requestInfo.RequestID,
|
||||
)
|
||||
if err == nil {
|
||||
w.streamWriter = streamWriter
|
||||
@@ -160,12 +204,16 @@ func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check request body for streaming indicators
|
||||
if w.requestInfo.Body != nil {
|
||||
// If a concrete Content-Type is already set (e.g., application/json for error responses),
|
||||
// treat it as non-streaming instead of inferring from the request payload.
|
||||
if strings.TrimSpace(contentType) != "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Only fall back to request payload hints when Content-Type is not set yet.
|
||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||
bodyStr := string(w.requestInfo.Body)
|
||||
if strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`) {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`)
|
||||
}
|
||||
|
||||
return false
|
||||
@@ -221,7 +269,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if w.isStreaming {
|
||||
if w.isStreaming && w.streamWriter != nil {
|
||||
if w.chunkChannel != nil {
|
||||
close(w.chunkChannel)
|
||||
w.chunkChannel = nil
|
||||
@@ -233,24 +281,19 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
}
|
||||
|
||||
// Write API Request and Response to the streaming log before closing
|
||||
if w.streamWriter != nil {
|
||||
apiRequest := w.extractAPIRequest(c)
|
||||
if len(apiRequest) > 0 {
|
||||
_ = w.streamWriter.WriteAPIRequest(apiRequest)
|
||||
}
|
||||
apiResponse := w.extractAPIResponse(c)
|
||||
if len(apiResponse) > 0 {
|
||||
_ = w.streamWriter.WriteAPIResponse(apiResponse)
|
||||
}
|
||||
if err := w.streamWriter.Close(); err != nil {
|
||||
w.streamWriter = nil
|
||||
return err
|
||||
}
|
||||
apiRequest := w.extractAPIRequest(c)
|
||||
if len(apiRequest) > 0 {
|
||||
_ = w.streamWriter.WriteAPIRequest(apiRequest)
|
||||
}
|
||||
apiResponse := w.extractAPIResponse(c)
|
||||
if len(apiResponse) > 0 {
|
||||
_ = w.streamWriter.WriteAPIResponse(apiResponse)
|
||||
}
|
||||
if err := w.streamWriter.Close(); err != nil {
|
||||
w.streamWriter = nil
|
||||
return err
|
||||
}
|
||||
if forceLog {
|
||||
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), slicesAPIResponseError, forceLog)
|
||||
}
|
||||
w.streamWriter = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -305,7 +348,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
|
||||
}
|
||||
|
||||
if loggerWithOptions, ok := w.logger.(interface {
|
||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool) error
|
||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string) error
|
||||
}); ok {
|
||||
return loggerWithOptions.LogRequestWithOptions(
|
||||
w.requestInfo.URL,
|
||||
@@ -319,6 +362,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
|
||||
apiResponseBody,
|
||||
apiResponseErrors,
|
||||
forceLog,
|
||||
w.requestInfo.RequestID,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -333,28 +377,6 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
|
||||
apiRequestBody,
|
||||
apiResponseBody,
|
||||
apiResponseErrors,
|
||||
w.requestInfo.RequestID,
|
||||
)
|
||||
}
|
||||
|
||||
// Status returns the HTTP response status code captured by the wrapper.
|
||||
// It defaults to 200 if WriteHeader has not been called.
|
||||
func (w *ResponseWriterWrapper) Status() int {
|
||||
if w.statusCode == 0 {
|
||||
return 200 // Default status code
|
||||
}
|
||||
return w.statusCode
|
||||
}
|
||||
|
||||
// Size returns the size of the response body in bytes for non-streaming responses.
|
||||
// For streaming responses, it returns -1, as the total size is unknown.
|
||||
func (w *ResponseWriterWrapper) Size() int {
|
||||
if w.isStreaming {
|
||||
return -1 // Unknown size for streaming responses
|
||||
}
|
||||
return w.body.Len()
|
||||
}
|
||||
|
||||
// Written returns true if the response header has been written (i.e., a status code has been set).
|
||||
func (w *ResponseWriterWrapper) Written() bool {
|
||||
return w.statusCode != 0
|
||||
}
|
||||
|
||||
@@ -137,7 +137,8 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
||||
m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)
|
||||
|
||||
// Register management proxy routes once; middleware will gate access when upstream is unavailable.
|
||||
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler)
|
||||
// Pass auth middleware to require valid API key for all management routes.
|
||||
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler, auth)
|
||||
|
||||
// If no upstream URL, skip proxy routes but provider aliases are still available
|
||||
if upstreamURL == "" {
|
||||
@@ -187,9 +188,6 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||
|
||||
if oldSettings != nil && oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost {
|
||||
m.setRestrictToLocalhost(newSettings.RestrictManagementToLocalhost)
|
||||
if !newSettings.RestrictManagementToLocalhost {
|
||||
log.Warnf("amp management routes now accessible from any IP - this is insecure!")
|
||||
}
|
||||
}
|
||||
|
||||
newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL)
|
||||
@@ -229,11 +227,20 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Check API key change
|
||||
// Check API key change (both default and per-client mappings)
|
||||
apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings)
|
||||
if apiKeyChanged {
|
||||
upstreamAPIKeysChanged := m.hasUpstreamAPIKeysChanged(oldSettings, &newSettings)
|
||||
if apiKeyChanged || upstreamAPIKeysChanged {
|
||||
if m.secretSource != nil {
|
||||
if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||
if ms, ok := m.secretSource.(*MappedSecretSource); ok {
|
||||
if apiKeyChanged {
|
||||
ms.UpdateDefaultExplicitKey(newSettings.UpstreamAPIKey)
|
||||
ms.InvalidateCache()
|
||||
}
|
||||
if upstreamAPIKeysChanged {
|
||||
ms.UpdateMappings(newSettings.UpstreamAPIKeys)
|
||||
}
|
||||
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||
ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
|
||||
ms.InvalidateCache()
|
||||
}
|
||||
@@ -253,10 +260,22 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||
|
||||
func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error {
|
||||
if m.secretSource == nil {
|
||||
m.secretSource = NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
|
||||
// Create MultiSourceSecret as the default source, then wrap with MappedSecretSource
|
||||
defaultSource := NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
|
||||
mappedSource := NewMappedSecretSource(defaultSource)
|
||||
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
|
||||
m.secretSource = mappedSource
|
||||
} else if ms, ok := m.secretSource.(*MappedSecretSource); ok {
|
||||
ms.UpdateDefaultExplicitKey(settings.UpstreamAPIKey)
|
||||
ms.InvalidateCache()
|
||||
ms.UpdateMappings(settings.UpstreamAPIKeys)
|
||||
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||
// Legacy path: wrap existing MultiSourceSecret with MappedSecretSource
|
||||
ms.UpdateExplicitKey(settings.UpstreamAPIKey)
|
||||
ms.InvalidateCache()
|
||||
mappedSource := NewMappedSecretSource(ms)
|
||||
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
|
||||
m.secretSource = mappedSource
|
||||
}
|
||||
|
||||
proxy, err := createReverseProxy(upstreamURL, m.secretSource)
|
||||
@@ -281,16 +300,23 @@ func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.Amp
|
||||
return true
|
||||
}
|
||||
|
||||
// Build map for efficient comparison
|
||||
oldMap := make(map[string]string, len(old.ModelMappings))
|
||||
// Build map for efficient and robust comparison
|
||||
type mappingInfo struct {
|
||||
to string
|
||||
regex bool
|
||||
}
|
||||
oldMap := make(map[string]mappingInfo, len(old.ModelMappings))
|
||||
for _, mapping := range old.ModelMappings {
|
||||
oldMap[strings.TrimSpace(mapping.From)] = strings.TrimSpace(mapping.To)
|
||||
oldMap[strings.TrimSpace(mapping.From)] = mappingInfo{
|
||||
to: strings.TrimSpace(mapping.To),
|
||||
regex: mapping.Regex,
|
||||
}
|
||||
}
|
||||
|
||||
for _, mapping := range new.ModelMappings {
|
||||
from := strings.TrimSpace(mapping.From)
|
||||
to := strings.TrimSpace(mapping.To)
|
||||
if oldTo, exists := oldMap[from]; !exists || oldTo != to {
|
||||
if oldVal, exists := oldMap[from]; !exists || oldVal.to != to || oldVal.regex != mapping.Regex {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -308,6 +334,66 @@ func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) b
|
||||
return oldKey != newKey
|
||||
}
|
||||
|
||||
// hasUpstreamAPIKeysChanged compares old and new per-client upstream API key mappings.
|
||||
func (m *AmpModule) hasUpstreamAPIKeysChanged(old *config.AmpCode, new *config.AmpCode) bool {
|
||||
if old == nil {
|
||||
return len(new.UpstreamAPIKeys) > 0
|
||||
}
|
||||
|
||||
if len(old.UpstreamAPIKeys) != len(new.UpstreamAPIKeys) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Build map for comparison: upstreamKey -> set of clientKeys
|
||||
type entryInfo struct {
|
||||
upstreamKey string
|
||||
clientKeys map[string]struct{}
|
||||
}
|
||||
oldEntries := make([]entryInfo, len(old.UpstreamAPIKeys))
|
||||
for i, entry := range old.UpstreamAPIKeys {
|
||||
clientKeys := make(map[string]struct{}, len(entry.APIKeys))
|
||||
for _, k := range entry.APIKeys {
|
||||
trimmed := strings.TrimSpace(k)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
clientKeys[trimmed] = struct{}{}
|
||||
}
|
||||
oldEntries[i] = entryInfo{
|
||||
upstreamKey: strings.TrimSpace(entry.UpstreamAPIKey),
|
||||
clientKeys: clientKeys,
|
||||
}
|
||||
}
|
||||
|
||||
for i, newEntry := range new.UpstreamAPIKeys {
|
||||
if i >= len(oldEntries) {
|
||||
return true
|
||||
}
|
||||
oldE := oldEntries[i]
|
||||
if strings.TrimSpace(newEntry.UpstreamAPIKey) != oldE.upstreamKey {
|
||||
return true
|
||||
}
|
||||
newKeys := make(map[string]struct{}, len(newEntry.APIKeys))
|
||||
for _, k := range newEntry.APIKeys {
|
||||
trimmed := strings.TrimSpace(k)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
newKeys[trimmed] = struct{}{}
|
||||
}
|
||||
if len(newKeys) != len(oldE.clientKeys) {
|
||||
return true
|
||||
}
|
||||
for k := range newKeys {
|
||||
if _, ok := oldE.clientKeys[k]; !ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetModelMapper returns the model mapper instance (for testing/debugging).
|
||||
func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
|
||||
return m.modelMapper
|
||||
|
||||
@@ -146,6 +146,9 @@ func TestAmpModule_OnConfigUpdated_CacheInvalidation(t *testing.T) {
|
||||
m := &AmpModule{enabled: true}
|
||||
ms := NewMultiSourceSecretWithPath("", p, time.Minute)
|
||||
m.secretSource = ms
|
||||
m.lastConfig = &config.AmpCode{
|
||||
UpstreamAPIKey: "old-key",
|
||||
}
|
||||
|
||||
// Warm the cache
|
||||
if _, err := ms.Get(context.Background()); err != nil {
|
||||
@@ -157,7 +160,7 @@ func TestAmpModule_OnConfigUpdated_CacheInvalidation(t *testing.T) {
|
||||
}
|
||||
|
||||
// Update config - should invalidate cache
|
||||
if err := m.OnConfigUpdated(&config.Config{AmpCode: config.AmpCode{UpstreamURL: "http://x"}}); err != nil {
|
||||
if err := m.OnConfigUpdated(&config.Config{AmpCode: config.AmpCode{UpstreamURL: "http://x", UpstreamAPIKey: "new-key"}}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -309,3 +312,41 @@ func TestAmpModule_ProviderAliasesAlwaysRegistered(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_hasUpstreamAPIKeysChanged_DetectsRemovedKeyWithDuplicateInput(t *testing.T) {
|
||||
m := &AmpModule{}
|
||||
|
||||
oldCfg := &config.AmpCode{
|
||||
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
|
||||
},
|
||||
}
|
||||
newCfg := &config.AmpCode{
|
||||
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k1"}},
|
||||
},
|
||||
}
|
||||
|
||||
if !m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
|
||||
t.Fatal("expected change to be detected when k2 is removed but new list contains duplicates")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_hasUpstreamAPIKeysChanged_IgnoresEmptyAndWhitespaceKeys(t *testing.T) {
|
||||
m := &AmpModule{}
|
||||
|
||||
oldCfg := &config.AmpCode{
|
||||
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
|
||||
},
|
||||
}
|
||||
newCfg := &config.AmpCode{
|
||||
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||
{UpstreamAPIKey: "u1", APIKeys: []string{" k1 ", "", "k2", " "}},
|
||||
},
|
||||
}
|
||||
|
||||
if m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
|
||||
t.Fatal("expected no change when only whitespace/empty entries differ")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -64,7 +65,7 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid
|
||||
fields["cost"] = "amp_credits"
|
||||
fields["source"] = "ampcode.com"
|
||||
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
||||
log.WithFields(fields).Warnf("forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local proxy, add to config: amp-model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
|
||||
log.WithFields(fields).Warnf("forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local provider, add to config: ampcode.model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
|
||||
|
||||
case RouteTypeNoProvider:
|
||||
fields["cost"] = "none"
|
||||
@@ -134,7 +135,44 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
}
|
||||
|
||||
// Normalize model (handles dynamic thinking suffixes)
|
||||
normalizedModel, _ := util.NormalizeThinkingModel(modelName)
|
||||
suffixResult := thinking.ParseSuffix(modelName)
|
||||
normalizedModel := suffixResult.ModelName
|
||||
thinkingSuffix := ""
|
||||
if suffixResult.HasSuffix {
|
||||
thinkingSuffix = "(" + suffixResult.RawSuffix + ")"
|
||||
}
|
||||
|
||||
resolveMappedModel := func() (string, []string) {
|
||||
if fh.modelMapper == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
mappedModel := fh.modelMapper.MapModel(modelName)
|
||||
if mappedModel == "" {
|
||||
mappedModel = fh.modelMapper.MapModel(normalizedModel)
|
||||
}
|
||||
mappedModel = strings.TrimSpace(mappedModel)
|
||||
if mappedModel == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target
|
||||
// already specifies its own thinking suffix.
|
||||
if thinkingSuffix != "" {
|
||||
mappedSuffixResult := thinking.ParseSuffix(mappedModel)
|
||||
if !mappedSuffixResult.HasSuffix {
|
||||
mappedModel += thinkingSuffix
|
||||
}
|
||||
}
|
||||
|
||||
mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName
|
||||
mappedProviders := util.GetProviderName(mappedBaseModel)
|
||||
if len(mappedProviders) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return mappedModel, mappedProviders
|
||||
}
|
||||
|
||||
// Track resolved model for logging (may change if mapping is applied)
|
||||
resolvedModel := normalizedModel
|
||||
@@ -147,21 +185,15 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
if forceMappings {
|
||||
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
|
||||
// This allows users to route Amp requests to their preferred OAuth providers
|
||||
if fh.modelMapper != nil {
|
||||
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
|
||||
// Mapping found - check if we have a provider for the mapped model
|
||||
mappedProviders := util.GetProviderName(mappedModel)
|
||||
if len(mappedProviders) > 0 {
|
||||
// Mapping found and provider available - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||
c.Set(MappedModelContextKey, mappedModel)
|
||||
resolvedModel = mappedModel
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
}
|
||||
}
|
||||
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
|
||||
// Mapping found and provider available - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||
c.Set(MappedModelContextKey, mappedModel)
|
||||
resolvedModel = mappedModel
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
}
|
||||
|
||||
// If no mapping applied, check for local providers
|
||||
@@ -174,21 +206,15 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
|
||||
if len(providers) == 0 {
|
||||
// No providers configured - check if we have a model mapping
|
||||
if fh.modelMapper != nil {
|
||||
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
|
||||
// Mapping found - check if we have a provider for the mapped model
|
||||
mappedProviders := util.GetProviderName(mappedModel)
|
||||
if len(mappedProviders) > 0 {
|
||||
// Mapping found and provider available - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||
c.Set(MappedModelContextKey, mappedModel)
|
||||
resolvedModel = mappedModel
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
}
|
||||
}
|
||||
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
|
||||
// Mapping found and provider available - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||
c.Set(MappedModelContextKey, mappedModel)
|
||||
resolvedModel = mappedModel
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -222,14 +248,14 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
// Log: Model was mapped to another model
|
||||
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
|
||||
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
||||
rewriter := NewResponseRewriter(c.Writer, normalizedModel)
|
||||
rewriter := NewResponseRewriter(c.Writer, modelName)
|
||||
c.Writer = rewriter
|
||||
// Filter Anthropic-Beta header only for local handling paths
|
||||
filterAntropicBetaHeader(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
handler(c)
|
||||
rewriter.Flush()
|
||||
log.Debugf("amp model mapping: response %s -> %s", resolvedModel, normalizedModel)
|
||||
log.Debugf("amp model mapping: response %s -> %s", resolvedModel, modelName)
|
||||
} else if len(providers) > 0 {
|
||||
// Log: Using local provider (free)
|
||||
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
||||
|
||||
73
internal/api/modules/amp/fallback_handlers_test.go
Normal file
73
internal/api/modules/amp/fallback_handlers_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
)
|
||||
|
||||
func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-amp-fallback", "codex", []*registry.ModelInfo{
|
||||
{ID: "test/gpt-5.2", OwnedBy: "openai", Type: "codex"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-amp-fallback")
|
||||
|
||||
mapper := NewModelMapper([]config.AmpModelMapping{
|
||||
{From: "gpt-5.2", To: "test/gpt-5.2"},
|
||||
})
|
||||
|
||||
fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, mapper, nil)
|
||||
|
||||
handler := func(c *gin.Context) {
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"model": req.Model,
|
||||
"seen_model": req.Model,
|
||||
})
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/chat/completions", fallback.WrapHandler(handler))
|
||||
|
||||
reqBody := []byte(`{"model":"gpt-5.2(xhigh)"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat/completions", bytes.NewReader(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Model string `json:"model"`
|
||||
SeenModel string `json:"seen_model"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Failed to parse response JSON: %v", err)
|
||||
}
|
||||
|
||||
if resp.Model != "gpt-5.2(xhigh)" {
|
||||
t.Errorf("Expected response model gpt-5.2(xhigh), got %s", resp.Model)
|
||||
}
|
||||
if resp.SeenModel != "test/gpt-5.2(xhigh)" {
|
||||
t.Errorf("Expected handler to see test/gpt-5.2(xhigh), got %s", resp.SeenModel)
|
||||
}
|
||||
}
|
||||
@@ -3,10 +3,12 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -26,13 +28,15 @@ type ModelMapper interface {
|
||||
// DefaultModelMapper implements ModelMapper with thread-safe mapping storage.
|
||||
type DefaultModelMapper struct {
|
||||
mu sync.RWMutex
|
||||
mappings map[string]string // from -> to (normalized lowercase keys)
|
||||
mappings map[string]string // exact: from -> to (normalized lowercase keys)
|
||||
regexps []regexMapping // regex rules evaluated in order
|
||||
}
|
||||
|
||||
// NewModelMapper creates a new model mapper with the given initial mappings.
|
||||
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
|
||||
m := &DefaultModelMapper{
|
||||
mappings: make(map[string]string),
|
||||
regexps: nil,
|
||||
}
|
||||
m.UpdateMappings(mappings)
|
||||
return m
|
||||
@@ -41,6 +45,11 @@ func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
|
||||
// MapModel checks if a mapping exists for the requested model and if the
|
||||
// target model has available local providers. Returns the mapped model name
|
||||
// or empty string if no valid mapping exists.
|
||||
//
|
||||
// If the requested model contains a thinking suffix (e.g., "g25p(8192)"),
|
||||
// the suffix is preserved in the returned model name (e.g., "gemini-2.5-pro(8192)").
|
||||
// However, if the mapping target already contains a suffix, the config suffix
|
||||
// takes priority over the user's suffix.
|
||||
func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
||||
if requestedModel == "" {
|
||||
return ""
|
||||
@@ -49,22 +58,52 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
// Normalize the requested model for lookup
|
||||
normalizedRequest := strings.ToLower(strings.TrimSpace(requestedModel))
|
||||
// Extract thinking suffix from requested model using ParseSuffix
|
||||
requestResult := thinking.ParseSuffix(requestedModel)
|
||||
baseModel := requestResult.ModelName
|
||||
|
||||
// Check for direct mapping
|
||||
targetModel, exists := m.mappings[normalizedRequest]
|
||||
// Normalize the base model for lookup (case-insensitive)
|
||||
normalizedBase := strings.ToLower(strings.TrimSpace(baseModel))
|
||||
|
||||
// Check for direct mapping using base model name
|
||||
targetModel, exists := m.mappings[normalizedBase]
|
||||
if !exists {
|
||||
return ""
|
||||
// Try regex mappings in order using base model only
|
||||
// (suffix is handled separately via ParseSuffix)
|
||||
for _, rm := range m.regexps {
|
||||
if rm.re.MatchString(baseModel) {
|
||||
targetModel = rm.to
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// Verify target model has available providers
|
||||
providers := util.GetProviderName(targetModel)
|
||||
// Check if target model already has a thinking suffix (config priority)
|
||||
targetResult := thinking.ParseSuffix(targetModel)
|
||||
|
||||
// Verify target model has available providers (use base model for lookup)
|
||||
providers := util.GetProviderName(targetResult.ModelName)
|
||||
if len(providers) == 0 {
|
||||
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
|
||||
return ""
|
||||
}
|
||||
|
||||
// Suffix handling: config suffix takes priority, otherwise preserve user suffix
|
||||
if targetResult.HasSuffix {
|
||||
// Config's "to" already contains a suffix - use it as-is (config priority)
|
||||
return targetModel
|
||||
}
|
||||
|
||||
// Preserve user's thinking suffix on the mapped model
|
||||
// (skip empty suffixes to avoid returning "model()")
|
||||
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
|
||||
return targetModel + "(" + requestResult.RawSuffix + ")"
|
||||
}
|
||||
|
||||
// Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go
|
||||
return targetModel
|
||||
}
|
||||
@@ -77,6 +116,7 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
|
||||
|
||||
// Clear and rebuild mappings
|
||||
m.mappings = make(map[string]string, len(mappings))
|
||||
m.regexps = make([]regexMapping, 0, len(mappings))
|
||||
|
||||
for _, mapping := range mappings {
|
||||
from := strings.TrimSpace(mapping.From)
|
||||
@@ -87,16 +127,30 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Store with normalized lowercase key for case-insensitive lookup
|
||||
normalizedFrom := strings.ToLower(from)
|
||||
m.mappings[normalizedFrom] = to
|
||||
|
||||
log.Debugf("amp model mapping registered: %s -> %s", from, to)
|
||||
if mapping.Regex {
|
||||
// Compile case-insensitive regex; wrap with (?i) to match behavior of exact lookups
|
||||
pattern := "(?i)" + from
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
log.Warnf("amp model mapping: invalid regex %q: %v", from, err)
|
||||
continue
|
||||
}
|
||||
m.regexps = append(m.regexps, regexMapping{re: re, to: to})
|
||||
log.Debugf("amp model regex mapping registered: /%s/ -> %s", from, to)
|
||||
} else {
|
||||
// Store with normalized lowercase key for case-insensitive lookup
|
||||
normalizedFrom := strings.ToLower(from)
|
||||
m.mappings[normalizedFrom] = to
|
||||
log.Debugf("amp model mapping registered: %s -> %s", from, to)
|
||||
}
|
||||
}
|
||||
|
||||
if len(m.mappings) > 0 {
|
||||
log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings))
|
||||
}
|
||||
if n := len(m.regexps); n > 0 {
|
||||
log.Infof("amp model mapping: loaded %d regex mapping(s)", n)
|
||||
}
|
||||
}
|
||||
|
||||
// GetMappings returns a copy of current mappings (for debugging/status).
|
||||
@@ -110,3 +164,8 @@ func (m *DefaultModelMapper) GetMappings() map[string]string {
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type regexMapping struct {
|
||||
re *regexp.Regexp
|
||||
to string
|
||||
}
|
||||
|
||||
@@ -71,6 +71,25 @@ func TestModelMapper_MapModel_WithProvider(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_MapModel_TargetWithThinkingSuffix(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-thinking", "codex", []*registry.ModelInfo{
|
||||
{ID: "gpt-5.2", OwnedBy: "openai", Type: "codex"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-thinking")
|
||||
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "gpt-5.2-alias", To: "gpt-5.2(xhigh)"},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
result := mapper.MapModel("gpt-5.2-alias")
|
||||
if result != "gpt-5.2(xhigh)" {
|
||||
t.Errorf("Expected gpt-5.2(xhigh), got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{
|
||||
@@ -184,3 +203,173 @@ func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) {
|
||||
t.Error("Original map was modified")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_Regex_MatchBaseWithoutParens(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-regex-1", "gemini", []*registry.ModelInfo{
|
||||
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-regex-1")
|
||||
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "^gpt-5$", To: "gemini-2.5-pro", Regex: true},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
// Incoming model has reasoning suffix, regex matches base, suffix is preserved
|
||||
result := mapper.MapModel("gpt-5(high)")
|
||||
if result != "gemini-2.5-pro(high)" {
|
||||
t.Errorf("Expected gemini-2.5-pro(high), got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_Regex_ExactPrecedence(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-regex-2", "claude", []*registry.ModelInfo{
|
||||
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
|
||||
})
|
||||
reg.RegisterClient("test-client-regex-3", "gemini", []*registry.ModelInfo{
|
||||
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-regex-2")
|
||||
defer reg.UnregisterClient("test-client-regex-3")
|
||||
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "gpt-5", To: "claude-sonnet-4"}, // exact
|
||||
{From: "^gpt-5.*$", To: "gemini-2.5-pro", Regex: true}, // regex
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
// Exact match should win over regex
|
||||
result := mapper.MapModel("gpt-5")
|
||||
if result != "claude-sonnet-4" {
|
||||
t.Errorf("Expected claude-sonnet-4, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_Regex_InvalidPattern_Skipped(t *testing.T) {
|
||||
// Invalid regex should be skipped and not cause panic
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "(", To: "target", Regex: true},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
result := mapper.MapModel("anything")
|
||||
if result != "" {
|
||||
t.Errorf("Expected empty result due to invalid regex, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_Regex_CaseInsensitive(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-regex-4", "claude", []*registry.ModelInfo{
|
||||
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-regex-4")
|
||||
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "^CLAUDE-OPUS-.*$", To: "claude-sonnet-4", Regex: true},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
result := mapper.MapModel("claude-opus-4.5")
|
||||
if result != "claude-sonnet-4" {
|
||||
t.Errorf("Expected claude-sonnet-4, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_SuffixPreservation(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
|
||||
// Register test models
|
||||
reg.RegisterClient("test-client-suffix", "gemini", []*registry.ModelInfo{
|
||||
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
|
||||
})
|
||||
reg.RegisterClient("test-client-suffix-2", "claude", []*registry.ModelInfo{
|
||||
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-suffix")
|
||||
defer reg.UnregisterClient("test-client-suffix-2")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mappings []config.AmpModelMapping
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "numeric suffix preserved",
|
||||
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
|
||||
input: "g25p(8192)",
|
||||
want: "gemini-2.5-pro(8192)",
|
||||
},
|
||||
{
|
||||
name: "level suffix preserved",
|
||||
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
|
||||
input: "g25p(high)",
|
||||
want: "gemini-2.5-pro(high)",
|
||||
},
|
||||
{
|
||||
name: "no suffix unchanged",
|
||||
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
|
||||
input: "g25p",
|
||||
want: "gemini-2.5-pro",
|
||||
},
|
||||
{
|
||||
name: "config suffix takes priority",
|
||||
mappings: []config.AmpModelMapping{{From: "alias", To: "gemini-2.5-pro(medium)"}},
|
||||
input: "alias(high)",
|
||||
want: "gemini-2.5-pro(medium)",
|
||||
},
|
||||
{
|
||||
name: "regex with suffix preserved",
|
||||
mappings: []config.AmpModelMapping{{From: "^g25.*", To: "gemini-2.5-pro", Regex: true}},
|
||||
input: "g25p(8192)",
|
||||
want: "gemini-2.5-pro(8192)",
|
||||
},
|
||||
{
|
||||
name: "auto suffix preserved",
|
||||
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
|
||||
input: "g25p(auto)",
|
||||
want: "gemini-2.5-pro(auto)",
|
||||
},
|
||||
{
|
||||
name: "none suffix preserved",
|
||||
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
|
||||
input: "g25p(none)",
|
||||
want: "gemini-2.5-pro(none)",
|
||||
},
|
||||
{
|
||||
name: "case insensitive base lookup with suffix",
|
||||
mappings: []config.AmpModelMapping{{From: "G25P", To: "gemini-2.5-pro"}},
|
||||
input: "g25p(high)",
|
||||
want: "gemini-2.5-pro(high)",
|
||||
},
|
||||
{
|
||||
name: "empty suffix filtered out",
|
||||
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
|
||||
input: "g25p()",
|
||||
want: "gemini-2.5-pro",
|
||||
},
|
||||
{
|
||||
name: "incomplete suffix treated as no suffix",
|
||||
mappings: []config.AmpModelMapping{{From: "g25p(high", To: "gemini-2.5-pro"}},
|
||||
input: "g25p(high",
|
||||
want: "gemini-2.5-pro",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mapper := NewModelMapper(tt.mappings)
|
||||
got := mapper.MapModel(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("MapModel(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,6 +18,33 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func removeQueryValuesMatching(req *http.Request, key string, match string) {
|
||||
if req == nil || req.URL == nil || match == "" {
|
||||
return
|
||||
}
|
||||
|
||||
q := req.URL.Query()
|
||||
values, ok := q[key]
|
||||
if !ok || len(values) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
kept := make([]string, 0, len(values))
|
||||
for _, v := range values {
|
||||
if v == match {
|
||||
continue
|
||||
}
|
||||
kept = append(kept, v)
|
||||
}
|
||||
|
||||
if len(kept) == 0 {
|
||||
q.Del(key)
|
||||
} else {
|
||||
q[key] = kept
|
||||
}
|
||||
req.URL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
// readCloser wraps a reader and forwards Close to a separate closer.
|
||||
// Used to restore peeked bytes while preserving upstream body Close behavior.
|
||||
type readCloser struct {
|
||||
@@ -44,6 +71,19 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
originalDirector(req)
|
||||
req.Host = parsed.Host
|
||||
|
||||
// Remove client's Authorization header - it was only used for CLI Proxy API authentication
|
||||
// We will set our own Authorization using the configured upstream-api-key
|
||||
req.Header.Del("Authorization")
|
||||
req.Header.Del("X-Api-Key")
|
||||
req.Header.Del("X-Goog-Api-Key")
|
||||
|
||||
// Remove query-based credentials if they match the authenticated client API key.
|
||||
// This prevents leaking client auth material to the Amp upstream while avoiding
|
||||
// breaking unrelated upstream query parameters.
|
||||
clientKey := getClientAPIKeyFromContext(req.Context())
|
||||
removeQueryValuesMatching(req, "key", clientKey)
|
||||
removeQueryValuesMatching(req, "auth_token", clientKey)
|
||||
|
||||
// Preserve correlation headers for debugging
|
||||
if req.Header.Get("X-Request-ID") == "" {
|
||||
// Could generate one here if needed
|
||||
@@ -53,7 +93,7 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
// Users going through ampcode.com proxy are paying for the service and should get all features
|
||||
// including 1M context window (context-1m-2025-08-07)
|
||||
|
||||
// Inject API key from secret source (precedence: config > env > file)
|
||||
// Inject API key from secret source (only uses upstream-api-key from config)
|
||||
if key, err := secretSource.Get(req.Context()); err == nil && key != "" {
|
||||
req.Header.Set("X-Api-Key", key)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
|
||||
|
||||
@@ -3,11 +3,15 @@ package amp
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
// Helper: compress data with gzip
|
||||
@@ -306,6 +310,159 @@ func TestReverseProxy_EmptySecret(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery(t *testing.T) {
|
||||
type captured struct {
|
||||
headers http.Header
|
||||
query string
|
||||
}
|
||||
got := make(chan captured, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
got <- captured{headers: r.Header.Clone(), query: r.URL.RawQuery}
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`ok`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("upstream"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Simulate clientAPIKeyMiddleware injection (per-request)
|
||||
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "client-key")
|
||||
proxy.ServeHTTP(w, r.WithContext(ctx))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, srv.URL+"/test?key=client-key&key=keep&auth_token=client-key&foo=bar", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer client-key")
|
||||
req.Header.Set("X-Api-Key", "client-key")
|
||||
req.Header.Set("X-Goog-Api-Key", "client-key")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
c := <-got
|
||||
|
||||
// These are client-provided credentials and must not reach the upstream.
|
||||
if v := c.headers.Get("X-Goog-Api-Key"); v != "" {
|
||||
t.Fatalf("X-Goog-Api-Key should be stripped, got: %q", v)
|
||||
}
|
||||
|
||||
// We inject upstream Authorization/X-Api-Key, so the client auth must not survive.
|
||||
if v := c.headers.Get("Authorization"); v != "Bearer upstream" {
|
||||
t.Fatalf("Authorization should be upstream-injected, got: %q", v)
|
||||
}
|
||||
if v := c.headers.Get("X-Api-Key"); v != "upstream" {
|
||||
t.Fatalf("X-Api-Key should be upstream-injected, got: %q", v)
|
||||
}
|
||||
|
||||
// Query-based credentials should be stripped only when they match the authenticated client key.
|
||||
// Should keep unrelated values and parameters.
|
||||
if strings.Contains(c.query, "auth_token=client-key") || strings.Contains(c.query, "key=client-key") {
|
||||
t.Fatalf("query credentials should be stripped, got raw query: %q", c.query)
|
||||
}
|
||||
if !strings.Contains(c.query, "key=keep") || !strings.Contains(c.query, "foo=bar") {
|
||||
t.Fatalf("expected query to keep non-credential params, got raw query: %q", c.query)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_InjectsMappedSecret_FromRequestContext(t *testing.T) {
|
||||
gotHeaders := make(chan http.Header, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotHeaders <- r.Header.Clone()
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`ok`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
defaultSource := NewStaticSecretSource("default")
|
||||
mapped := NewMappedSecretSource(defaultSource)
|
||||
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||
{
|
||||
UpstreamAPIKey: "u1",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
})
|
||||
|
||||
proxy, err := createReverseProxy(upstream.URL, mapped)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Simulate clientAPIKeyMiddleware injection (per-request)
|
||||
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k1")
|
||||
proxy.ServeHTTP(w, r.WithContext(ctx))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
res, err := http.Get(srv.URL + "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
hdr := <-gotHeaders
|
||||
if hdr.Get("X-Api-Key") != "u1" {
|
||||
t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
|
||||
}
|
||||
if hdr.Get("Authorization") != "Bearer u1" {
|
||||
t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_MappedSecret_FallsBackToDefault(t *testing.T) {
|
||||
gotHeaders := make(chan http.Header, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotHeaders <- r.Header.Clone()
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`ok`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
defaultSource := NewStaticSecretSource("default")
|
||||
mapped := NewMappedSecretSource(defaultSource)
|
||||
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||
{
|
||||
UpstreamAPIKey: "u1",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
})
|
||||
|
||||
proxy, err := createReverseProxy(upstream.URL, mapped)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k2")
|
||||
proxy.ServeHTTP(w, r.WithContext(ctx))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
res, err := http.Get(srv.URL + "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
hdr := <-gotHeaders
|
||||
if hdr.Get("X-Api-Key") != "default" {
|
||||
t.Fatalf("X-Api-Key fallback missing or wrong, got: %q", hdr.Get("X-Api-Key"))
|
||||
}
|
||||
if hdr.Get("Authorization") != "Bearer default" {
|
||||
t.Fatalf("Authorization fallback missing or wrong, got: %q", hdr.Get("Authorization"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_ErrorHandler(t *testing.T) {
|
||||
// Point proxy to a non-routable address to trigger error
|
||||
proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource(""))
|
||||
|
||||
@@ -125,7 +125,30 @@ func (rw *ResponseRewriter) Flush() {
|
||||
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
|
||||
|
||||
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
|
||||
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility
|
||||
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
||||
// 1. Amp Compatibility: Suppress thinking blocks if tool use is detected
|
||||
// The Amp client struggles when both thinking and tool_use blocks are present
|
||||
if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
|
||||
filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
|
||||
if filtered.Exists() {
|
||||
originalCount := gjson.GetBytes(data, "content.#").Int()
|
||||
filteredCount := filtered.Get("#").Int()
|
||||
|
||||
if originalCount > filteredCount {
|
||||
var err error
|
||||
data, err = sjson.SetBytes(data, "content", filtered.Value())
|
||||
if err != nil {
|
||||
log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
|
||||
} else {
|
||||
log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount)
|
||||
// Log the result for verification
|
||||
log.Debugf("Amp ResponseRewriter: Resulting content: %s", gjson.GetBytes(data, "content").String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rw.originalModel == "" {
|
||||
return data
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -16,6 +17,37 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// clientAPIKeyContextKey is the context key used to pass the client API key
|
||||
// from gin.Context to the request context for SecretSource lookup.
|
||||
type clientAPIKeyContextKey struct{}
|
||||
|
||||
// clientAPIKeyMiddleware injects the authenticated client API key from gin.Context["apiKey"]
|
||||
// into the request context so that SecretSource can look it up for per-client upstream routing.
|
||||
func clientAPIKeyMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Extract the client API key from gin context (set by AuthMiddleware)
|
||||
if apiKey, exists := c.Get("apiKey"); exists {
|
||||
if keyStr, ok := apiKey.(string); ok && keyStr != "" {
|
||||
// Inject into request context for SecretSource.Get(ctx) to read
|
||||
ctx := context.WithValue(c.Request.Context(), clientAPIKeyContextKey{}, keyStr)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// getClientAPIKeyFromContext retrieves the client API key from request context.
|
||||
// Returns empty string if not present.
|
||||
func getClientAPIKeyFromContext(ctx context.Context) string {
|
||||
if val := ctx.Value(clientAPIKeyContextKey{}); val != nil {
|
||||
if keyStr, ok := val.(string); ok {
|
||||
return keyStr
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// localhostOnlyMiddleware returns a middleware that dynamically checks the module's
|
||||
// localhost restriction setting. This allows hot-reload of the restriction without restarting.
|
||||
func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc {
|
||||
@@ -95,10 +127,25 @@ func (m *AmpModule) managementAvailabilityMiddleware() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// wrapManagementAuth skips auth for selected management paths while keeping authentication elsewhere.
|
||||
func wrapManagementAuth(auth gin.HandlerFunc, prefixes ...string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
for _, prefix := range prefixes {
|
||||
if strings.HasPrefix(path, prefix) && (len(path) == len(prefix) || path[len(prefix)] == '/') {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
auth(c)
|
||||
}
|
||||
}
|
||||
|
||||
// registerManagementRoutes registers Amp management proxy routes
|
||||
// These routes proxy through to the Amp control plane for OAuth, user management, etc.
|
||||
// Uses dynamic middleware and proxy getter for hot-reload support.
|
||||
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler) {
|
||||
// The auth middleware validates Authorization header against configured API keys.
|
||||
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, auth gin.HandlerFunc) {
|
||||
ampAPI := engine.Group("/api")
|
||||
|
||||
// Always disable CORS for management routes to prevent browser-based attacks
|
||||
@@ -107,10 +154,16 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
// Apply dynamic localhost-only restriction (hot-reloadable via m.IsRestrictedToLocalhost())
|
||||
ampAPI.Use(m.localhostOnlyMiddleware())
|
||||
|
||||
if !m.IsRestrictedToLocalhost() {
|
||||
log.Warn("amp management routes are NOT restricted to localhost - this is insecure!")
|
||||
// Apply authentication middleware - requires valid API key in Authorization header
|
||||
var authWithBypass gin.HandlerFunc
|
||||
if auth != nil {
|
||||
ampAPI.Use(auth)
|
||||
authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings")
|
||||
}
|
||||
|
||||
// Inject client API key into request context for per-client upstream routing
|
||||
ampAPI.Use(clientAPIKeyMiddleware())
|
||||
|
||||
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
|
||||
proxyHandler := func(c *gin.Context) {
|
||||
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
|
||||
@@ -154,7 +207,18 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
// Root-level routes that AMP CLI expects without /api prefix
|
||||
// These need the same security middleware as the /api/* routes (dynamic for hot-reload)
|
||||
rootMiddleware := []gin.HandlerFunc{m.managementAvailabilityMiddleware(), noCORSMiddleware(), m.localhostOnlyMiddleware()}
|
||||
if authWithBypass != nil {
|
||||
rootMiddleware = append(rootMiddleware, authWithBypass)
|
||||
}
|
||||
// Add clientAPIKeyMiddleware after auth for per-client upstream routing
|
||||
rootMiddleware = append(rootMiddleware, clientAPIKeyMiddleware())
|
||||
engine.GET("/threads", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/docs", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/docs/*path", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/settings", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/settings/*path", append(rootMiddleware, proxyHandler)...)
|
||||
|
||||
engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/news.rss", append(rootMiddleware, proxyHandler)...)
|
||||
|
||||
@@ -217,6 +281,8 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
||||
if auth != nil {
|
||||
ampProviders.Use(auth)
|
||||
}
|
||||
// Inject client API key into request context for per-client upstream routing
|
||||
ampProviders.Use(clientAPIKeyMiddleware())
|
||||
|
||||
provider := ampProviders.Group("/:provider")
|
||||
|
||||
@@ -262,7 +328,7 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
||||
v1betaAmp := provider.Group("/v1beta")
|
||||
{
|
||||
v1betaAmp.GET("/models", geminiHandlers.GeminiModels)
|
||||
v1betaAmp.POST("/models/:action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
|
||||
v1betaAmp.GET("/models/:action", geminiHandlers.GeminiGetHandler)
|
||||
v1betaAmp.POST("/models/*action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
|
||||
v1betaAmp.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,9 @@ func TestRegisterManagementRoutes(t *testing.T) {
|
||||
m.setProxy(proxy)
|
||||
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
m.registerManagementRoutes(r, base)
|
||||
m.registerManagementRoutes(r, base, nil)
|
||||
srv := httptest.NewServer(r)
|
||||
defer srv.Close()
|
||||
|
||||
managementPaths := []struct {
|
||||
path string
|
||||
@@ -63,11 +65,17 @@ func TestRegisterManagementRoutes(t *testing.T) {
|
||||
for _, path := range managementPaths {
|
||||
t.Run(path.path, func(t *testing.T) {
|
||||
proxyCalled = false
|
||||
req := httptest.NewRequest(path.method, path.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
req, err := http.NewRequest(path.method, srv.URL+path.path, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build request: %v", err)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if w.Code == http.StatusNotFound {
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
t.Fatalf("route %s not registered", path.path)
|
||||
}
|
||||
if !proxyCalled {
|
||||
|
||||
@@ -9,6 +9,9 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// SecretSource provides Amp API keys with configurable precedence and caching
|
||||
@@ -164,3 +167,82 @@ func NewStaticSecretSource(key string) *StaticSecretSource {
|
||||
func (s *StaticSecretSource) Get(ctx context.Context) (string, error) {
|
||||
return s.key, nil
|
||||
}
|
||||
|
||||
// MappedSecretSource wraps a default SecretSource and adds per-client API key mapping.
|
||||
// When a request context contains a client API key that matches a configured mapping,
|
||||
// the corresponding upstream key is returned. Otherwise, falls back to the default source.
|
||||
type MappedSecretSource struct {
|
||||
defaultSource SecretSource
|
||||
mu sync.RWMutex
|
||||
lookup map[string]string // clientKey -> upstreamKey
|
||||
}
|
||||
|
||||
// NewMappedSecretSource creates a MappedSecretSource wrapping the given default source.
|
||||
func NewMappedSecretSource(defaultSource SecretSource) *MappedSecretSource {
|
||||
return &MappedSecretSource{
|
||||
defaultSource: defaultSource,
|
||||
lookup: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves the Amp API key, checking per-client mappings first.
|
||||
// If the request context contains a client API key that matches a configured mapping,
|
||||
// returns the corresponding upstream key. Otherwise, falls back to the default source.
|
||||
func (s *MappedSecretSource) Get(ctx context.Context) (string, error) {
|
||||
// Try to get client API key from request context
|
||||
clientKey := getClientAPIKeyFromContext(ctx)
|
||||
if clientKey != "" {
|
||||
s.mu.RLock()
|
||||
if upstreamKey, ok := s.lookup[clientKey]; ok && upstreamKey != "" {
|
||||
s.mu.RUnlock()
|
||||
return upstreamKey, nil
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
}
|
||||
|
||||
// Fall back to default source
|
||||
return s.defaultSource.Get(ctx)
|
||||
}
|
||||
|
||||
// UpdateMappings rebuilds the client-to-upstream key mapping from configuration entries.
|
||||
// If the same client key appears in multiple entries, logs a warning and uses the first one.
|
||||
func (s *MappedSecretSource) UpdateMappings(entries []config.AmpUpstreamAPIKeyEntry) {
|
||||
newLookup := make(map[string]string)
|
||||
|
||||
for _, entry := range entries {
|
||||
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
|
||||
if upstreamKey == "" {
|
||||
continue
|
||||
}
|
||||
for _, clientKey := range entry.APIKeys {
|
||||
trimmedKey := strings.TrimSpace(clientKey)
|
||||
if trimmedKey == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := newLookup[trimmedKey]; exists {
|
||||
// Log warning for duplicate client key, first one wins
|
||||
log.Warnf("amp upstream-api-keys: client API key appears in multiple entries; using first mapping.")
|
||||
continue
|
||||
}
|
||||
newLookup[trimmedKey] = upstreamKey
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.lookup = newLookup
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// UpdateDefaultExplicitKey updates the explicit key on the underlying MultiSourceSecret (if applicable).
|
||||
func (s *MappedSecretSource) UpdateDefaultExplicitKey(key string) {
|
||||
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
|
||||
ms.UpdateExplicitKey(key)
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateCache invalidates cache on the underlying MultiSourceSecret (if applicable).
|
||||
func (s *MappedSecretSource) InvalidateCache() {
|
||||
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
|
||||
ms.InvalidateCache()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,10 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/sirupsen/logrus/hooks/test"
|
||||
)
|
||||
|
||||
func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) {
|
||||
@@ -278,3 +282,85 @@ func TestMultiSourceSecret_CacheEmptyResult(t *testing.T) {
|
||||
t.Fatalf("after cache expiry, expected new-value, got %q", got3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMappedSecretSource_UsesMappingFromContext(t *testing.T) {
|
||||
defaultSource := NewStaticSecretSource("default")
|
||||
s := NewMappedSecretSource(defaultSource)
|
||||
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||
{
|
||||
UpstreamAPIKey: "u1",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
})
|
||||
|
||||
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
|
||||
got, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "u1" {
|
||||
t.Fatalf("want u1, got %q", got)
|
||||
}
|
||||
|
||||
ctx = context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k2")
|
||||
got, err = s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "default" {
|
||||
t.Fatalf("want default fallback, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMappedSecretSource_DuplicateClientKey_FirstWins(t *testing.T) {
|
||||
defaultSource := NewStaticSecretSource("default")
|
||||
s := NewMappedSecretSource(defaultSource)
|
||||
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||
{
|
||||
UpstreamAPIKey: "u1",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
{
|
||||
UpstreamAPIKey: "u2",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
})
|
||||
|
||||
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
|
||||
got, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "u1" {
|
||||
t.Fatalf("want u1 (first wins), got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMappedSecretSource_DuplicateClientKey_LogsWarning(t *testing.T) {
|
||||
hook := test.NewLocal(log.StandardLogger())
|
||||
defer hook.Reset()
|
||||
|
||||
defaultSource := NewStaticSecretSource("default")
|
||||
s := NewMappedSecretSource(defaultSource)
|
||||
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||
{
|
||||
UpstreamAPIKey: "u1",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
{
|
||||
UpstreamAPIKey: "u2",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
})
|
||||
|
||||
foundWarning := false
|
||||
for _, entry := range hook.AllEntries() {
|
||||
if entry.Level == log.WarnLevel && entry.Message == "amp upstream-api-keys: client API key appears in multiple entries; using first mapping." {
|
||||
foundWarning = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundWarning {
|
||||
t.Fatal("expected warning log for duplicate client key, but none was found")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,9 +23,11 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/middleware"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
|
||||
ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
@@ -33,6 +35,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gopkg.in/yaml.v3"
|
||||
@@ -209,13 +212,15 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
// Resolve logs directory relative to the configuration file directory.
|
||||
var requestLogger logging.RequestLogger
|
||||
var toggle func(bool)
|
||||
if optionState.requestLoggerFactory != nil {
|
||||
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
|
||||
}
|
||||
if requestLogger != nil {
|
||||
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
|
||||
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
|
||||
toggle = setter.SetEnabled
|
||||
if !cfg.CommercialMode {
|
||||
if optionState.requestLoggerFactory != nil {
|
||||
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
|
||||
}
|
||||
if requestLogger != nil {
|
||||
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
|
||||
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
|
||||
toggle = setter.SetEnabled
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -230,13 +235,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
envManagementSecret := envAdminPasswordSet && envAdminPassword != ""
|
||||
|
||||
// Create server instance
|
||||
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
|
||||
for _, p := range cfg.OpenAICompatibility {
|
||||
providerNames = append(providerNames, p.Name)
|
||||
}
|
||||
s := &Server{
|
||||
engine: engine,
|
||||
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager, providerNames),
|
||||
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager),
|
||||
cfg: cfg,
|
||||
accessManager: accessManager,
|
||||
requestLogger: requestLogger,
|
||||
@@ -255,15 +256,13 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
}
|
||||
managementasset.SetCurrentConfig(cfg)
|
||||
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||
misc.SetCodexInstructionsEnabled(cfg.CodexInstructionsEnabled)
|
||||
// Initialize management handler
|
||||
s.mgmt = managementHandlers.NewHandler(cfg, configFilePath, authManager)
|
||||
if optionState.localPassword != "" {
|
||||
s.mgmt.SetLocalPassword(optionState.localPassword)
|
||||
}
|
||||
logDir := filepath.Join(s.currentPath, "logs")
|
||||
if base := util.WritablePath(); base != "" {
|
||||
logDir = filepath.Join(base, "logs")
|
||||
}
|
||||
logDir := logging.ResolveLogDirectory(cfg)
|
||||
s.mgmt.SetLogDirectory(logDir)
|
||||
s.localPassword = optionState.localPassword
|
||||
|
||||
@@ -294,6 +293,11 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
s.registerManagementRoutes()
|
||||
}
|
||||
|
||||
// === CLIProxyAPIPlus 扩展: 注册 Kiro OAuth Web 路由 ===
|
||||
kiroOAuthHandler := kiro.NewOAuthWebHandler(cfg)
|
||||
kiroOAuthHandler.RegisterRoutes(engine)
|
||||
log.Info("Kiro OAuth Web routes registered at /v0/oauth/kiro/*")
|
||||
|
||||
if optionState.keepAliveEnabled {
|
||||
s.enableKeepAlive(optionState.keepAliveTimeout, optionState.keepAliveOnTimeout)
|
||||
}
|
||||
@@ -334,8 +338,8 @@ func (s *Server) setupRoutes() {
|
||||
v1beta.Use(AuthMiddleware(s.accessManager))
|
||||
{
|
||||
v1beta.GET("/models", geminiHandlers.GeminiModels)
|
||||
v1beta.POST("/models/:action", geminiHandlers.GeminiHandler)
|
||||
v1beta.GET("/models/:action", geminiHandlers.GeminiGetHandler)
|
||||
v1beta.POST("/models/*action", geminiHandlers.GeminiHandler)
|
||||
v1beta.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
||||
}
|
||||
|
||||
// Root endpoint
|
||||
@@ -364,10 +368,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
// Persist to a temporary file keyed by state
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-anthropic-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "anthropic", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -377,9 +382,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-codex-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "codex", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -389,9 +396,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-gemini-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gemini", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -401,9 +410,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-iflow-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "iflow", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -413,9 +424,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-antigravity-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "antigravity", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -425,9 +438,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-kiro-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "kiro", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -487,6 +502,8 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware())
|
||||
{
|
||||
mgmt.GET("/usage", s.mgmt.GetUsageStatistics)
|
||||
mgmt.GET("/usage/export", s.mgmt.ExportUsageStatistics)
|
||||
mgmt.POST("/usage/import", s.mgmt.ImportUsageStatistics)
|
||||
mgmt.GET("/config", s.mgmt.GetConfig)
|
||||
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
|
||||
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
|
||||
@@ -500,6 +517,10 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.PUT("/logging-to-file", s.mgmt.PutLoggingToFile)
|
||||
mgmt.PATCH("/logging-to-file", s.mgmt.PutLoggingToFile)
|
||||
|
||||
mgmt.GET("/logs-max-total-size-mb", s.mgmt.GetLogsMaxTotalSizeMB)
|
||||
mgmt.PUT("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
|
||||
mgmt.PATCH("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
|
||||
|
||||
mgmt.GET("/usage-statistics-enabled", s.mgmt.GetUsageStatisticsEnabled)
|
||||
mgmt.PUT("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
|
||||
mgmt.PATCH("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
|
||||
@@ -509,6 +530,8 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.PATCH("/proxy-url", s.mgmt.PutProxyURL)
|
||||
mgmt.DELETE("/proxy-url", s.mgmt.DeleteProxyURL)
|
||||
|
||||
mgmt.POST("/api-call", s.mgmt.APICall)
|
||||
|
||||
mgmt.GET("/quota-exceeded/switch-project", s.mgmt.GetSwitchProject)
|
||||
mgmt.PUT("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
|
||||
mgmt.PATCH("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
|
||||
@@ -531,6 +554,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.DELETE("/logs", s.mgmt.DeleteLogs)
|
||||
mgmt.GET("/request-error-logs", s.mgmt.GetRequestErrorLogs)
|
||||
mgmt.GET("/request-error-logs/:name", s.mgmt.DownloadRequestErrorLog)
|
||||
mgmt.GET("/request-log-by-id/:id", s.mgmt.GetRequestLogByID)
|
||||
mgmt.GET("/request-log", s.mgmt.GetRequestLog)
|
||||
mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
|
||||
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
|
||||
@@ -557,6 +581,10 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings)
|
||||
mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
||||
mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
||||
mgmt.GET("/ampcode/upstream-api-keys", s.mgmt.GetAmpUpstreamAPIKeys)
|
||||
mgmt.PUT("/ampcode/upstream-api-keys", s.mgmt.PutAmpUpstreamAPIKeys)
|
||||
mgmt.PATCH("/ampcode/upstream-api-keys", s.mgmt.PatchAmpUpstreamAPIKeys)
|
||||
mgmt.DELETE("/ampcode/upstream-api-keys", s.mgmt.DeleteAmpUpstreamAPIKeys)
|
||||
|
||||
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
|
||||
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
|
||||
@@ -565,6 +593,14 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.PUT("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
|
||||
mgmt.PATCH("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
|
||||
|
||||
mgmt.GET("/force-model-prefix", s.mgmt.GetForceModelPrefix)
|
||||
mgmt.PUT("/force-model-prefix", s.mgmt.PutForceModelPrefix)
|
||||
mgmt.PATCH("/force-model-prefix", s.mgmt.PutForceModelPrefix)
|
||||
|
||||
mgmt.GET("/routing/strategy", s.mgmt.GetRoutingStrategy)
|
||||
mgmt.PUT("/routing/strategy", s.mgmt.PutRoutingStrategy)
|
||||
mgmt.PATCH("/routing/strategy", s.mgmt.PutRoutingStrategy)
|
||||
|
||||
mgmt.GET("/claude-api-key", s.mgmt.GetClaudeKeys)
|
||||
mgmt.PUT("/claude-api-key", s.mgmt.PutClaudeKeys)
|
||||
mgmt.PATCH("/claude-api-key", s.mgmt.PatchClaudeKey)
|
||||
@@ -580,16 +616,27 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.PATCH("/openai-compatibility", s.mgmt.PatchOpenAICompat)
|
||||
mgmt.DELETE("/openai-compatibility", s.mgmt.DeleteOpenAICompat)
|
||||
|
||||
mgmt.GET("/vertex-api-key", s.mgmt.GetVertexCompatKeys)
|
||||
mgmt.PUT("/vertex-api-key", s.mgmt.PutVertexCompatKeys)
|
||||
mgmt.PATCH("/vertex-api-key", s.mgmt.PatchVertexCompatKey)
|
||||
mgmt.DELETE("/vertex-api-key", s.mgmt.DeleteVertexCompatKey)
|
||||
|
||||
mgmt.GET("/oauth-excluded-models", s.mgmt.GetOAuthExcludedModels)
|
||||
mgmt.PUT("/oauth-excluded-models", s.mgmt.PutOAuthExcludedModels)
|
||||
mgmt.PATCH("/oauth-excluded-models", s.mgmt.PatchOAuthExcludedModels)
|
||||
mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels)
|
||||
|
||||
mgmt.GET("/oauth-model-alias", s.mgmt.GetOAuthModelAlias)
|
||||
mgmt.PUT("/oauth-model-alias", s.mgmt.PutOAuthModelAlias)
|
||||
mgmt.PATCH("/oauth-model-alias", s.mgmt.PatchOAuthModelAlias)
|
||||
mgmt.DELETE("/oauth-model-alias", s.mgmt.DeleteOAuthModelAlias)
|
||||
|
||||
mgmt.GET("/auth-files", s.mgmt.ListAuthFiles)
|
||||
mgmt.GET("/auth-files/models", s.mgmt.GetAuthFileModels)
|
||||
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
|
||||
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
||||
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
|
||||
mgmt.PATCH("/auth-files/status", s.mgmt.PatchAuthFileStatus)
|
||||
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
|
||||
|
||||
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
||||
@@ -600,6 +647,8 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
|
||||
mgmt.GET("/github-auth-url", s.mgmt.RequestGitHubToken)
|
||||
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
||||
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
||||
}
|
||||
}
|
||||
@@ -628,7 +677,7 @@ func (s *Server) serveManagementControlPanel(c *gin.Context) {
|
||||
|
||||
if _, err := os.Stat(filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL)
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
@@ -857,11 +906,20 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
}
|
||||
}
|
||||
|
||||
if oldCfg != nil && oldCfg.LoggingToFile != cfg.LoggingToFile {
|
||||
if err := logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
|
||||
if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
|
||||
if err := logging.ConfigureLogOutput(cfg); err != nil {
|
||||
log.Errorf("failed to reconfigure log output: %v", err)
|
||||
} else {
|
||||
log.Debugf("logging_to_file updated from %t to %t", oldCfg.LoggingToFile, cfg.LoggingToFile)
|
||||
if oldCfg == nil {
|
||||
log.Debug("log output configuration refreshed")
|
||||
} else {
|
||||
if oldCfg.LoggingToFile != cfg.LoggingToFile {
|
||||
log.Debugf("logging_to_file updated from %t to %t", oldCfg.LoggingToFile, cfg.LoggingToFile)
|
||||
}
|
||||
if oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
|
||||
log.Debugf("logs_max_total_size_mb updated from %d to %d", oldCfg.LogsMaxTotalSizeMB, cfg.LogsMaxTotalSizeMB)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -882,6 +940,16 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
log.Debugf("disable_cooling toggled to %t", cfg.DisableCooling)
|
||||
}
|
||||
}
|
||||
|
||||
if oldCfg == nil || oldCfg.CodexInstructionsEnabled != cfg.CodexInstructionsEnabled {
|
||||
misc.SetCodexInstructionsEnabled(cfg.CodexInstructionsEnabled)
|
||||
if oldCfg != nil {
|
||||
log.Debugf("codex_instructions_enabled updated from %t to %t", oldCfg.CodexInstructionsEnabled, cfg.CodexInstructionsEnabled)
|
||||
} else {
|
||||
log.Debugf("codex_instructions_enabled toggled to %t", cfg.CodexInstructionsEnabled)
|
||||
}
|
||||
}
|
||||
|
||||
if s.handlers != nil && s.handlers.AuthManager != nil {
|
||||
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
|
||||
}
|
||||
@@ -938,17 +1006,11 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
// Save YAML snapshot for next comparison
|
||||
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
||||
|
||||
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
|
||||
for _, p := range cfg.OpenAICompatibility {
|
||||
providerNames = append(providerNames, p.Name)
|
||||
}
|
||||
s.handlers.SetOpenAICompatProviders(providerNames)
|
||||
|
||||
s.handlers.UpdateClients(&cfg.SDKConfig)
|
||||
|
||||
if !cfg.RemoteManagement.DisableControlPanel {
|
||||
staticDir := managementasset.StaticDir(s.configFilePath)
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL)
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
}
|
||||
if s.mgmt != nil {
|
||||
s.mgmt.SetConfig(cfg)
|
||||
@@ -965,8 +1027,12 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
log.Warnf("amp module is nil, skipping config update")
|
||||
}
|
||||
|
||||
// Count client sources from configuration and auth directory
|
||||
authFiles := util.CountAuthFiles(cfg.AuthDir)
|
||||
// Count client sources from configuration and auth store.
|
||||
tokenStore := sdkAuth.GetTokenStore()
|
||||
if dirSetter, ok := tokenStore.(interface{ SetBaseDir(string) }); ok {
|
||||
dirSetter.SetBaseDir(cfg.AuthDir)
|
||||
}
|
||||
authEntries := util.CountAuthFiles(context.Background(), tokenStore)
|
||||
geminiAPIKeyCount := len(cfg.GeminiKey)
|
||||
claudeAPIKeyCount := len(cfg.ClaudeKey)
|
||||
codexAPIKeyCount := len(cfg.CodexKey)
|
||||
@@ -977,10 +1043,10 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
openAICompatCount += len(entry.APIKeyEntries)
|
||||
}
|
||||
|
||||
total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
|
||||
fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
|
||||
total := authEntries + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
|
||||
fmt.Printf("server clients and configuration updated: %d clients (%d auth entries + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
|
||||
total,
|
||||
authFiles,
|
||||
authEntries,
|
||||
geminiAPIKeyCount,
|
||||
claudeAPIKeyCount,
|
||||
codexAPIKeyCount,
|
||||
|
||||
46
internal/auth/codex/filename.go
Normal file
46
internal/auth/codex/filename.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package codex
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// CredentialFileName returns the filename used to persist Codex OAuth credentials.
|
||||
// When planType is available (e.g. "plus", "team"), it is appended after the email
|
||||
// as a suffix to disambiguate subscriptions.
|
||||
func CredentialFileName(email, planType, hashAccountID string, includeProviderPrefix bool) string {
|
||||
email = strings.TrimSpace(email)
|
||||
plan := normalizePlanTypeForFilename(planType)
|
||||
|
||||
prefix := ""
|
||||
if includeProviderPrefix {
|
||||
prefix = "codex"
|
||||
}
|
||||
|
||||
if plan == "" {
|
||||
return fmt.Sprintf("%s-%s.json", prefix, email)
|
||||
} else if plan == "team" {
|
||||
return fmt.Sprintf("%s-%s-%s-%s.json", prefix, hashAccountID, email, plan)
|
||||
}
|
||||
return fmt.Sprintf("%s-%s-%s.json", prefix, email, plan)
|
||||
}
|
||||
|
||||
func normalizePlanTypeForFilename(planType string) string {
|
||||
planType = strings.TrimSpace(planType)
|
||||
if planType == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
parts := strings.FieldsFunc(planType, func(r rune) bool {
|
||||
return !unicode.IsLetter(r) && !unicode.IsDigit(r)
|
||||
})
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
for i, part := range parts {
|
||||
parts[i] = strings.ToLower(strings.TrimSpace(part))
|
||||
}
|
||||
return strings.Join(parts, "-")
|
||||
}
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -28,8 +29,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
geminiOauthClientSecret = "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"
|
||||
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
geminiOauthClientSecret = "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"
|
||||
geminiDefaultCallbackPort = 8085
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -46,6 +48,13 @@ var (
|
||||
type GeminiAuth struct {
|
||||
}
|
||||
|
||||
// WebLoginOptions customizes the interactive OAuth flow.
|
||||
type WebLoginOptions struct {
|
||||
NoBrowser bool
|
||||
CallbackPort int
|
||||
Prompt func(string) (string, error)
|
||||
}
|
||||
|
||||
// NewGeminiAuth creates a new instance of GeminiAuth.
|
||||
func NewGeminiAuth() *GeminiAuth {
|
||||
return &GeminiAuth{}
|
||||
@@ -59,12 +68,18 @@ func NewGeminiAuth() *GeminiAuth {
|
||||
// - ctx: The context for the HTTP client
|
||||
// - ts: The Gemini token storage containing authentication tokens
|
||||
// - cfg: The configuration containing proxy settings
|
||||
// - noBrowser: Optional parameter to disable browser opening
|
||||
// - opts: Optional parameters to customize browser and prompt behavior
|
||||
//
|
||||
// Returns:
|
||||
// - *http.Client: An HTTP client configured with authentication
|
||||
// - error: An error if the client configuration fails, nil otherwise
|
||||
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, noBrowser ...bool) (*http.Client, error) {
|
||||
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
|
||||
callbackPort := geminiDefaultCallbackPort
|
||||
if opts != nil && opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
|
||||
|
||||
// Configure proxy settings for the HTTP client if a proxy URL is provided.
|
||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
||||
if err == nil {
|
||||
@@ -99,7 +114,7 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
|
||||
conf := &oauth2.Config{
|
||||
ClientID: geminiOauthClientID,
|
||||
ClientSecret: geminiOauthClientSecret,
|
||||
RedirectURL: "http://localhost:8085/oauth2callback", // This will be used by the local server.
|
||||
RedirectURL: callbackURL, // This will be used by the local server.
|
||||
Scopes: geminiOauthScopes,
|
||||
Endpoint: google.Endpoint,
|
||||
}
|
||||
@@ -109,7 +124,7 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
|
||||
// If no token is found in storage, initiate the web-based OAuth flow.
|
||||
if ts.Token == nil {
|
||||
fmt.Printf("Could not load token from file, starting OAuth flow.\n")
|
||||
token, err = g.getTokenFromWeb(ctx, conf, noBrowser...)
|
||||
token, err = g.getTokenFromWeb(ctx, conf, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get token from web: %w", err)
|
||||
}
|
||||
@@ -205,35 +220,50 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
|
||||
// Parameters:
|
||||
// - ctx: The context for the HTTP client
|
||||
// - config: The OAuth2 configuration
|
||||
// - noBrowser: Optional parameter to disable browser opening
|
||||
// - opts: Optional parameters to customize browser and prompt behavior
|
||||
//
|
||||
// Returns:
|
||||
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow
|
||||
// - error: An error if the token acquisition fails, nil otherwise
|
||||
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, noBrowser ...bool) (*oauth2.Token, error) {
|
||||
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
|
||||
callbackPort := geminiDefaultCallbackPort
|
||||
if opts != nil && opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
|
||||
|
||||
// Use a channel to pass the authorization code from the HTTP handler to the main function.
|
||||
codeChan := make(chan string)
|
||||
errChan := make(chan error)
|
||||
codeChan := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
// Create a new HTTP server with its own multiplexer.
|
||||
mux := http.NewServeMux()
|
||||
server := &http.Server{Addr: ":8085", Handler: mux}
|
||||
config.RedirectURL = "http://localhost:8085/oauth2callback"
|
||||
server := &http.Server{Addr: fmt.Sprintf(":%d", callbackPort), Handler: mux}
|
||||
config.RedirectURL = callbackURL
|
||||
|
||||
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.URL.Query().Get("error"); err != "" {
|
||||
_, _ = fmt.Fprintf(w, "Authentication failed: %s", err)
|
||||
errChan <- fmt.Errorf("authentication failed via callback: %s", err)
|
||||
select {
|
||||
case errChan <- fmt.Errorf("authentication failed via callback: %s", err):
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
code := r.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
_, _ = fmt.Fprint(w, "Authentication failed: code not found.")
|
||||
errChan <- fmt.Errorf("code not found in callback")
|
||||
select {
|
||||
case errChan <- fmt.Errorf("code not found in callback"):
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
_, _ = fmt.Fprint(w, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
|
||||
codeChan <- code
|
||||
select {
|
||||
case codeChan <- code:
|
||||
default:
|
||||
}
|
||||
})
|
||||
|
||||
// Start the server in a goroutine.
|
||||
@@ -250,19 +280,24 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
||||
// Open the authorization URL in the user's browser.
|
||||
authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
|
||||
|
||||
if len(noBrowser) == 1 && !noBrowser[0] {
|
||||
noBrowser := false
|
||||
if opts != nil {
|
||||
noBrowser = opts.NoBrowser
|
||||
}
|
||||
|
||||
if !noBrowser {
|
||||
fmt.Println("Opening browser for authentication...")
|
||||
|
||||
// Check if browser is available
|
||||
if !browser.IsAvailable() {
|
||||
log.Warn("No browser available on this system")
|
||||
util.PrintSSHTunnelInstructions(8085)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
||||
} else {
|
||||
if err := browser.OpenURL(authURL); err != nil {
|
||||
authErr := codex.NewAuthenticationError(codex.ErrBrowserOpenFailed, err)
|
||||
log.Warn(codex.GetUserFriendlyMessage(authErr))
|
||||
util.PrintSSHTunnelInstructions(8085)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
||||
|
||||
// Log platform info for debugging
|
||||
@@ -273,7 +308,7 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
util.PrintSSHTunnelInstructions(8085)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Please open this URL in your browser:\n\n%s\n", authURL)
|
||||
}
|
||||
|
||||
@@ -281,13 +316,60 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
||||
|
||||
// Wait for the authorization code or an error.
|
||||
var authCode string
|
||||
select {
|
||||
case code := <-codeChan:
|
||||
authCode = code
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
case <-time.After(5 * time.Minute): // Timeout
|
||||
return nil, fmt.Errorf("oauth flow timed out")
|
||||
timeoutTimer := time.NewTimer(5 * time.Minute)
|
||||
defer timeoutTimer.Stop()
|
||||
|
||||
var manualPromptTimer *time.Timer
|
||||
var manualPromptC <-chan time.Time
|
||||
if opts != nil && opts.Prompt != nil {
|
||||
manualPromptTimer = time.NewTimer(15 * time.Second)
|
||||
manualPromptC = manualPromptTimer.C
|
||||
defer manualPromptTimer.Stop()
|
||||
}
|
||||
|
||||
waitForCallback:
|
||||
for {
|
||||
select {
|
||||
case code := <-codeChan:
|
||||
authCode = code
|
||||
break waitForCallback
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
case <-manualPromptC:
|
||||
manualPromptC = nil
|
||||
if manualPromptTimer != nil {
|
||||
manualPromptTimer.Stop()
|
||||
}
|
||||
select {
|
||||
case code := <-codeChan:
|
||||
authCode = code
|
||||
break waitForCallback
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
default:
|
||||
}
|
||||
input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsed, err := misc.ParseOAuthCallback(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsed == nil {
|
||||
continue
|
||||
}
|
||||
if parsed.Error != "" {
|
||||
return nil, fmt.Errorf("authentication failed via callback: %s", parsed.Error)
|
||||
}
|
||||
if parsed.Code == "" {
|
||||
return nil, fmt.Errorf("code not found in callback")
|
||||
}
|
||||
authCode = parsed.Code
|
||||
break waitForCallback
|
||||
case <-timeoutTimer.C:
|
||||
return nil, fmt.Errorf("oauth flow timed out")
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown the server.
|
||||
|
||||
@@ -5,10 +5,12 @@ package kiro
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
|
||||
@@ -40,6 +42,10 @@ type KiroTokenData struct {
|
||||
ClientSecret string `json:"clientSecret,omitempty"`
|
||||
// Email is the user's email address (used for file naming)
|
||||
Email string `json:"email,omitempty"`
|
||||
// StartURL is the IDC/Identity Center start URL (only for IDC auth method)
|
||||
StartURL string `json:"startUrl,omitempty"`
|
||||
// Region is the AWS region for IDC authentication (only for IDC auth method)
|
||||
Region string `json:"region,omitempty"`
|
||||
}
|
||||
|
||||
// KiroAuthBundle aggregates authentication data after OAuth flow completion
|
||||
@@ -81,6 +87,87 @@ type KiroModel struct {
|
||||
// KiroIDETokenFile is the default path to Kiro IDE's token file
|
||||
const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json"
|
||||
|
||||
// Default retry configuration for file reading
|
||||
const (
|
||||
defaultTokenReadMaxAttempts = 10 // Maximum retry attempts
|
||||
defaultTokenReadBaseDelay = 50 * time.Millisecond // Base delay between retries
|
||||
)
|
||||
|
||||
// isTransientFileError checks if the error is a transient file access error
|
||||
// that may be resolved by retrying (e.g., file locked by another process on Windows).
|
||||
func isTransientFileError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for OS-level file access errors (Windows sharing violation, etc.)
|
||||
var pathErr *os.PathError
|
||||
if errors.As(err, &pathErr) {
|
||||
// Windows sharing violation (ERROR_SHARING_VIOLATION = 32)
|
||||
// Windows lock violation (ERROR_LOCK_VIOLATION = 33)
|
||||
errStr := pathErr.Err.Error()
|
||||
if strings.Contains(errStr, "being used by another process") ||
|
||||
strings.Contains(errStr, "sharing violation") ||
|
||||
strings.Contains(errStr, "lock violation") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check error message for common transient patterns
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
transientPatterns := []string{
|
||||
"being used by another process",
|
||||
"sharing violation",
|
||||
"lock violation",
|
||||
"access is denied",
|
||||
"unexpected end of json",
|
||||
"unexpected eof",
|
||||
}
|
||||
for _, pattern := range transientPatterns {
|
||||
if strings.Contains(errMsg, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// LoadKiroIDETokenWithRetry loads token data from Kiro IDE's token file with retry logic.
|
||||
// This handles transient file access errors (e.g., file locked by Kiro IDE during write).
|
||||
// maxAttempts: maximum number of retry attempts (default 10 if <= 0)
|
||||
// baseDelay: base delay between retries with exponential backoff (default 50ms if <= 0)
|
||||
func LoadKiroIDETokenWithRetry(maxAttempts int, baseDelay time.Duration) (*KiroTokenData, error) {
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = defaultTokenReadMaxAttempts
|
||||
}
|
||||
if baseDelay <= 0 {
|
||||
baseDelay = defaultTokenReadBaseDelay
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||
token, err := LoadKiroIDEToken()
|
||||
if err == nil {
|
||||
return token, nil
|
||||
}
|
||||
lastErr = err
|
||||
|
||||
// Only retry for transient errors
|
||||
if !isTransientFileError(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Exponential backoff: delay * 2^attempt, capped at 500ms
|
||||
delay := baseDelay * time.Duration(1<<uint(attempt))
|
||||
if delay > 500*time.Millisecond {
|
||||
delay = 500 * time.Millisecond
|
||||
}
|
||||
time.Sleep(delay)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to read token file after %d attempts: %w", maxAttempts, lastErr)
|
||||
}
|
||||
|
||||
// LoadKiroIDEToken loads token data from Kiro IDE's token file.
|
||||
func LoadKiroIDEToken() (*KiroTokenData, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
@@ -103,6 +190,9 @@ func LoadKiroIDEToken() (*KiroTokenData, error) {
|
||||
return nil, fmt.Errorf("access token is empty in Kiro IDE token file")
|
||||
}
|
||||
|
||||
// Normalize AuthMethod to lowercase (Kiro IDE uses "IdC" but we expect "idc")
|
||||
token.AuthMethod = strings.ToLower(token.AuthMethod)
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
@@ -132,6 +222,9 @@ func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) {
|
||||
return nil, fmt.Errorf("access token is empty in token file")
|
||||
}
|
||||
|
||||
// Normalize AuthMethod to lowercase (Kiro IDE uses "IdC" but we expect "idc")
|
||||
token.AuthMethod = strings.ToLower(token.AuthMethod)
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -280,6 +280,11 @@ func (k *KiroAuth) CreateTokenStorage(tokenData *KiroTokenData) *KiroTokenStorag
|
||||
AuthMethod: tokenData.AuthMethod,
|
||||
Provider: tokenData.Provider,
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
ClientID: tokenData.ClientID,
|
||||
ClientSecret: tokenData.ClientSecret,
|
||||
Region: tokenData.Region,
|
||||
StartURL: tokenData.StartURL,
|
||||
Email: tokenData.Email,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -311,4 +316,19 @@ func (k *KiroAuth) UpdateTokenStorage(storage *KiroTokenStorage, tokenData *Kiro
|
||||
storage.AuthMethod = tokenData.AuthMethod
|
||||
storage.Provider = tokenData.Provider
|
||||
storage.LastRefresh = time.Now().Format(time.RFC3339)
|
||||
if tokenData.ClientID != "" {
|
||||
storage.ClientID = tokenData.ClientID
|
||||
}
|
||||
if tokenData.ClientSecret != "" {
|
||||
storage.ClientSecret = tokenData.ClientSecret
|
||||
}
|
||||
if tokenData.Region != "" {
|
||||
storage.Region = tokenData.Region
|
||||
}
|
||||
if tokenData.StartURL != "" {
|
||||
storage.StartURL = tokenData.StartURL
|
||||
}
|
||||
if tokenData.Email != "" {
|
||||
storage.Email = tokenData.Email
|
||||
}
|
||||
}
|
||||
|
||||
224
internal/auth/kiro/background_refresh.go
Normal file
224
internal/auth/kiro/background_refresh.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"golang.org/x/sync/semaphore"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
ID string
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
ExpiresAt time.Time
|
||||
LastVerified time.Time
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
AuthMethod string
|
||||
Provider string
|
||||
StartURL string
|
||||
Region string
|
||||
}
|
||||
|
||||
type TokenRepository interface {
|
||||
FindOldestUnverified(limit int) []*Token
|
||||
UpdateToken(token *Token) error
|
||||
}
|
||||
|
||||
type RefresherOption func(*BackgroundRefresher)
|
||||
|
||||
func WithInterval(interval time.Duration) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.interval = interval
|
||||
}
|
||||
}
|
||||
|
||||
func WithBatchSize(size int) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.batchSize = size
|
||||
}
|
||||
}
|
||||
|
||||
func WithConcurrency(concurrency int) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.concurrency = concurrency
|
||||
}
|
||||
}
|
||||
|
||||
type BackgroundRefresher struct {
|
||||
interval time.Duration
|
||||
batchSize int
|
||||
concurrency int
|
||||
tokenRepo TokenRepository
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
oauth *KiroOAuth
|
||||
ssoClient *SSOOIDCClient
|
||||
callbackMu sync.RWMutex // 保护回调函数的并发访问
|
||||
onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // 刷新成功回调
|
||||
}
|
||||
|
||||
func NewBackgroundRefresher(repo TokenRepository, opts ...RefresherOption) *BackgroundRefresher {
|
||||
r := &BackgroundRefresher{
|
||||
interval: time.Minute,
|
||||
batchSize: 50,
|
||||
concurrency: 10,
|
||||
tokenRepo: repo,
|
||||
stopCh: make(chan struct{}),
|
||||
oauth: nil, // Lazy init - will be set when config available
|
||||
ssoClient: nil, // Lazy init - will be set when config available
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(r)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// WithConfig sets the configuration for OAuth and SSO clients.
|
||||
func WithConfig(cfg *config.Config) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.oauth = NewKiroOAuth(cfg)
|
||||
r.ssoClient = NewSSOOIDCClient(cfg)
|
||||
}
|
||||
}
|
||||
|
||||
// WithOnTokenRefreshed sets the callback function to be called when a token is successfully refreshed.
|
||||
// The callback receives the token ID (filename) and the new token data.
|
||||
// This allows external components (e.g., Watcher) to be notified of token updates.
|
||||
func WithOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.callbackMu.Lock()
|
||||
r.onTokenRefreshed = callback
|
||||
r.callbackMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *BackgroundRefresher) Start(ctx context.Context) {
|
||||
r.wg.Add(1)
|
||||
go func() {
|
||||
defer r.wg.Done()
|
||||
ticker := time.NewTicker(r.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
r.refreshBatch(ctx)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-r.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
r.refreshBatch(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (r *BackgroundRefresher) Stop() {
|
||||
close(r.stopCh)
|
||||
r.wg.Wait()
|
||||
}
|
||||
|
||||
func (r *BackgroundRefresher) refreshBatch(ctx context.Context) {
|
||||
tokens := r.tokenRepo.FindOldestUnverified(r.batchSize)
|
||||
if len(tokens) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
sem := semaphore.NewWeighted(int64(r.concurrency))
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i, token := range tokens {
|
||||
if i > 0 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-r.stopCh:
|
||||
return
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
if err := sem.Acquire(ctx, 1); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(t *Token) {
|
||||
defer wg.Done()
|
||||
defer sem.Release(1)
|
||||
r.refreshSingle(ctx, t)
|
||||
}(token)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (r *BackgroundRefresher) refreshSingle(ctx context.Context, token *Token) {
|
||||
var newTokenData *KiroTokenData
|
||||
var err error
|
||||
|
||||
switch token.AuthMethod {
|
||||
case "idc":
|
||||
newTokenData, err = r.ssoClient.RefreshTokenWithRegion(
|
||||
ctx,
|
||||
token.ClientID,
|
||||
token.ClientSecret,
|
||||
token.RefreshToken,
|
||||
token.Region,
|
||||
token.StartURL,
|
||||
)
|
||||
case "builder-id":
|
||||
newTokenData, err = r.ssoClient.RefreshToken(
|
||||
ctx,
|
||||
token.ClientID,
|
||||
token.ClientSecret,
|
||||
token.RefreshToken,
|
||||
)
|
||||
default:
|
||||
newTokenData, err = r.oauth.RefreshToken(ctx, token.RefreshToken)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Printf("failed to refresh token %s: %v", token.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
token.AccessToken = newTokenData.AccessToken
|
||||
token.RefreshToken = newTokenData.RefreshToken
|
||||
token.LastVerified = time.Now()
|
||||
|
||||
if newTokenData.ExpiresAt != "" {
|
||||
if expTime, parseErr := time.Parse(time.RFC3339, newTokenData.ExpiresAt); parseErr == nil {
|
||||
token.ExpiresAt = expTime
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.tokenRepo.UpdateToken(token); err != nil {
|
||||
log.Printf("failed to update token %s: %v", token.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 方案 A: 刷新成功后触发回调,通知 Watcher 更新内存中的 Auth 对象
|
||||
r.callbackMu.RLock()
|
||||
callback := r.onTokenRefreshed
|
||||
r.callbackMu.RUnlock()
|
||||
|
||||
if callback != nil {
|
||||
// 使用 defer recover 隔离回调 panic,防止崩溃整个进程
|
||||
func() {
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
log.Printf("background refresh: callback panic for token %s: %v", token.ID, rec)
|
||||
}
|
||||
}()
|
||||
log.Printf("background refresh: notifying token refresh callback for %s", token.ID)
|
||||
callback(token.ID, newTokenData)
|
||||
}()
|
||||
}
|
||||
}
|
||||
166
internal/auth/kiro/codewhisperer_client.go
Normal file
166
internal/auth/kiro/codewhisperer_client.go
Normal file
@@ -0,0 +1,166 @@
|
||||
// Package kiro provides CodeWhisperer API client for fetching user info.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
codeWhispererAPI = "https://codewhisperer.us-east-1.amazonaws.com"
|
||||
kiroVersion = "0.6.18"
|
||||
)
|
||||
|
||||
// CodeWhispererClient handles CodeWhisperer API calls.
|
||||
type CodeWhispererClient struct {
|
||||
httpClient *http.Client
|
||||
machineID string
|
||||
}
|
||||
|
||||
// UsageLimitsResponse represents the getUsageLimits API response.
|
||||
type UsageLimitsResponse struct {
|
||||
DaysUntilReset *int `json:"daysUntilReset,omitempty"`
|
||||
NextDateReset *float64 `json:"nextDateReset,omitempty"`
|
||||
UserInfo *UserInfo `json:"userInfo,omitempty"`
|
||||
SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
|
||||
UsageBreakdownList []UsageBreakdown `json:"usageBreakdownList,omitempty"`
|
||||
}
|
||||
|
||||
// UserInfo contains user information from the API.
|
||||
type UserInfo struct {
|
||||
Email string `json:"email,omitempty"`
|
||||
UserID string `json:"userId,omitempty"`
|
||||
}
|
||||
|
||||
// SubscriptionInfo contains subscription details.
|
||||
type SubscriptionInfo struct {
|
||||
SubscriptionTitle string `json:"subscriptionTitle,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
}
|
||||
|
||||
// UsageBreakdown contains usage details.
|
||||
type UsageBreakdown struct {
|
||||
UsageLimit *int `json:"usageLimit,omitempty"`
|
||||
CurrentUsage *int `json:"currentUsage,omitempty"`
|
||||
UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision,omitempty"`
|
||||
CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision,omitempty"`
|
||||
NextDateReset *float64 `json:"nextDateReset,omitempty"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
ResourceType string `json:"resourceType,omitempty"`
|
||||
}
|
||||
|
||||
// NewCodeWhispererClient creates a new CodeWhisperer client.
|
||||
func NewCodeWhispererClient(cfg *config.Config, machineID string) *CodeWhispererClient {
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
if cfg != nil {
|
||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||
}
|
||||
if machineID == "" {
|
||||
machineID = uuid.New().String()
|
||||
}
|
||||
return &CodeWhispererClient{
|
||||
httpClient: client,
|
||||
machineID: machineID,
|
||||
}
|
||||
}
|
||||
|
||||
// generateInvocationID generates a unique invocation ID.
|
||||
func generateInvocationID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
// GetUsageLimits fetches usage limits and user info from CodeWhisperer API.
|
||||
// This is the recommended way to get user email after login.
|
||||
func (c *CodeWhispererClient) GetUsageLimits(ctx context.Context, accessToken string) (*UsageLimitsResponse, error) {
|
||||
url := fmt.Sprintf("%s/getUsageLimits?isEmailRequired=true&origin=AI_EDITOR&resourceType=AGENTIC_REQUEST", codeWhispererAPI)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Set headers to match Kiro IDE
|
||||
xAmzUserAgent := fmt.Sprintf("aws-sdk-js/1.0.0 KiroIDE-%s-%s", kiroVersion, c.machineID)
|
||||
userAgent := fmt.Sprintf("aws-sdk-js/1.0.0 ua/2.1 os/windows lang/js md/nodejs#20.16.0 api/codewhispererruntime#1.0.0 m/E KiroIDE-%s-%s", kiroVersion, c.machineID)
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("x-amz-user-agent", xAmzUserAgent)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
req.Header.Set("amz-sdk-invocation-id", generateInvocationID())
|
||||
req.Header.Set("amz-sdk-request", "attempt=1; max=1")
|
||||
req.Header.Set("Connection", "close")
|
||||
|
||||
log.Debugf("codewhisperer: GET %s", url)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("codewhisperer: status=%d, body=%s", resp.StatusCode, string(body))
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result UsageLimitsResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// FetchUserEmailFromAPI fetches user email using CodeWhisperer getUsageLimits API.
|
||||
// This is more reliable than JWT parsing as it uses the official API.
|
||||
func (c *CodeWhispererClient) FetchUserEmailFromAPI(ctx context.Context, accessToken string) string {
|
||||
resp, err := c.GetUsageLimits(ctx, accessToken)
|
||||
if err != nil {
|
||||
log.Debugf("codewhisperer: failed to get usage limits: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
if resp.UserInfo != nil && resp.UserInfo.Email != "" {
|
||||
log.Debugf("codewhisperer: got email from API: %s", resp.UserInfo.Email)
|
||||
return resp.UserInfo.Email
|
||||
}
|
||||
|
||||
log.Debugf("codewhisperer: no email in response")
|
||||
return ""
|
||||
}
|
||||
|
||||
// FetchUserEmailWithFallback fetches user email with multiple fallback methods.
|
||||
// Priority: 1. CodeWhisperer API 2. userinfo endpoint 3. JWT parsing
|
||||
func FetchUserEmailWithFallback(ctx context.Context, cfg *config.Config, accessToken string) string {
|
||||
// Method 1: Try CodeWhisperer API (most reliable)
|
||||
cwClient := NewCodeWhispererClient(cfg, "")
|
||||
email := cwClient.FetchUserEmailFromAPI(ctx, accessToken)
|
||||
if email != "" {
|
||||
return email
|
||||
}
|
||||
|
||||
// Method 2: Try SSO OIDC userinfo endpoint
|
||||
ssoClient := NewSSOOIDCClient(cfg)
|
||||
email = ssoClient.FetchUserEmail(ctx, accessToken)
|
||||
if email != "" {
|
||||
return email
|
||||
}
|
||||
|
||||
// Method 3: Fallback to JWT parsing
|
||||
return ExtractEmailFromJWT(accessToken)
|
||||
}
|
||||
112
internal/auth/kiro/cooldown.go
Normal file
112
internal/auth/kiro/cooldown.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
CooldownReason429 = "rate_limit_exceeded"
|
||||
CooldownReasonSuspended = "account_suspended"
|
||||
CooldownReasonQuotaExhausted = "quota_exhausted"
|
||||
|
||||
DefaultShortCooldown = 1 * time.Minute
|
||||
MaxShortCooldown = 5 * time.Minute
|
||||
LongCooldown = 24 * time.Hour
|
||||
)
|
||||
|
||||
type CooldownManager struct {
|
||||
mu sync.RWMutex
|
||||
cooldowns map[string]time.Time
|
||||
reasons map[string]string
|
||||
}
|
||||
|
||||
func NewCooldownManager() *CooldownManager {
|
||||
return &CooldownManager{
|
||||
cooldowns: make(map[string]time.Time),
|
||||
reasons: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) SetCooldown(tokenKey string, duration time.Duration, reason string) {
|
||||
cm.mu.Lock()
|
||||
defer cm.mu.Unlock()
|
||||
cm.cooldowns[tokenKey] = time.Now().Add(duration)
|
||||
cm.reasons[tokenKey] = reason
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) IsInCooldown(tokenKey string) bool {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
endTime, exists := cm.cooldowns[tokenKey]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
return time.Now().Before(endTime)
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) GetRemainingCooldown(tokenKey string) time.Duration {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
endTime, exists := cm.cooldowns[tokenKey]
|
||||
if !exists {
|
||||
return 0
|
||||
}
|
||||
remaining := time.Until(endTime)
|
||||
if remaining < 0 {
|
||||
return 0
|
||||
}
|
||||
return remaining
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) GetCooldownReason(tokenKey string) string {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return cm.reasons[tokenKey]
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) ClearCooldown(tokenKey string) {
|
||||
cm.mu.Lock()
|
||||
defer cm.mu.Unlock()
|
||||
delete(cm.cooldowns, tokenKey)
|
||||
delete(cm.reasons, tokenKey)
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) CleanupExpired() {
|
||||
cm.mu.Lock()
|
||||
defer cm.mu.Unlock()
|
||||
now := time.Now()
|
||||
for tokenKey, endTime := range cm.cooldowns {
|
||||
if now.After(endTime) {
|
||||
delete(cm.cooldowns, tokenKey)
|
||||
delete(cm.reasons, tokenKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) StartCleanupRoutine(interval time.Duration, stopCh <-chan struct{}) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
cm.CleanupExpired()
|
||||
case <-stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CalculateCooldownFor429(retryCount int) time.Duration {
|
||||
duration := DefaultShortCooldown * time.Duration(1<<retryCount)
|
||||
if duration > MaxShortCooldown {
|
||||
return MaxShortCooldown
|
||||
}
|
||||
return duration
|
||||
}
|
||||
|
||||
func CalculateCooldownUntilNextDay() time.Duration {
|
||||
now := time.Now()
|
||||
nextDay := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, now.Location())
|
||||
return time.Until(nextDay)
|
||||
}
|
||||
240
internal/auth/kiro/cooldown_test.go
Normal file
240
internal/auth/kiro/cooldown_test.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewCooldownManager(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
if cm == nil {
|
||||
t.Fatal("expected non-nil CooldownManager")
|
||||
}
|
||||
if cm.cooldowns == nil {
|
||||
t.Error("expected non-nil cooldowns map")
|
||||
}
|
||||
if cm.reasons == nil {
|
||||
t.Error("expected non-nil reasons map")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCooldown(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Minute, CooldownReason429)
|
||||
|
||||
if !cm.IsInCooldown("token1") {
|
||||
t.Error("expected token to be in cooldown")
|
||||
}
|
||||
if cm.GetCooldownReason("token1") != CooldownReason429 {
|
||||
t.Errorf("expected reason %s, got %s", CooldownReason429, cm.GetCooldownReason("token1"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsInCooldown_NotSet(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
if cm.IsInCooldown("nonexistent") {
|
||||
t.Error("expected non-existent token to not be in cooldown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsInCooldown_Expired(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Millisecond, CooldownReason429)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
if cm.IsInCooldown("token1") {
|
||||
t.Error("expected expired cooldown to return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRemainingCooldown(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Second, CooldownReason429)
|
||||
|
||||
remaining := cm.GetRemainingCooldown("token1")
|
||||
if remaining <= 0 || remaining > 1*time.Second {
|
||||
t.Errorf("expected remaining cooldown between 0 and 1s, got %v", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRemainingCooldown_NotSet(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
remaining := cm.GetRemainingCooldown("nonexistent")
|
||||
if remaining != 0 {
|
||||
t.Errorf("expected 0 remaining for non-existent, got %v", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRemainingCooldown_Expired(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Millisecond, CooldownReason429)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
remaining := cm.GetRemainingCooldown("token1")
|
||||
if remaining != 0 {
|
||||
t.Errorf("expected 0 remaining for expired, got %v", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCooldownReason(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Minute, CooldownReasonSuspended)
|
||||
|
||||
reason := cm.GetCooldownReason("token1")
|
||||
if reason != CooldownReasonSuspended {
|
||||
t.Errorf("expected reason %s, got %s", CooldownReasonSuspended, reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCooldownReason_NotSet(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
reason := cm.GetCooldownReason("nonexistent")
|
||||
if reason != "" {
|
||||
t.Errorf("expected empty reason for non-existent, got %s", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearCooldown(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Minute, CooldownReason429)
|
||||
cm.ClearCooldown("token1")
|
||||
|
||||
if cm.IsInCooldown("token1") {
|
||||
t.Error("expected cooldown to be cleared")
|
||||
}
|
||||
if cm.GetCooldownReason("token1") != "" {
|
||||
t.Error("expected reason to be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearCooldown_NonExistent(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.ClearCooldown("nonexistent")
|
||||
}
|
||||
|
||||
func TestCleanupExpired(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("expired1", 1*time.Millisecond, CooldownReason429)
|
||||
cm.SetCooldown("expired2", 1*time.Millisecond, CooldownReason429)
|
||||
cm.SetCooldown("active", 1*time.Hour, CooldownReason429)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
cm.CleanupExpired()
|
||||
|
||||
if cm.GetCooldownReason("expired1") != "" {
|
||||
t.Error("expected expired1 to be cleaned up")
|
||||
}
|
||||
if cm.GetCooldownReason("expired2") != "" {
|
||||
t.Error("expected expired2 to be cleaned up")
|
||||
}
|
||||
if cm.GetCooldownReason("active") != CooldownReason429 {
|
||||
t.Error("expected active to remain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCooldownFor429_FirstRetry(t *testing.T) {
|
||||
duration := CalculateCooldownFor429(0)
|
||||
if duration != DefaultShortCooldown {
|
||||
t.Errorf("expected %v for retry 0, got %v", DefaultShortCooldown, duration)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCooldownFor429_Exponential(t *testing.T) {
|
||||
d1 := CalculateCooldownFor429(1)
|
||||
d2 := CalculateCooldownFor429(2)
|
||||
|
||||
if d2 <= d1 {
|
||||
t.Errorf("expected d2 > d1, got d1=%v, d2=%v", d1, d2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCooldownFor429_MaxCap(t *testing.T) {
|
||||
duration := CalculateCooldownFor429(10)
|
||||
if duration > MaxShortCooldown {
|
||||
t.Errorf("expected max %v, got %v", MaxShortCooldown, duration)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCooldownUntilNextDay(t *testing.T) {
|
||||
duration := CalculateCooldownUntilNextDay()
|
||||
if duration <= 0 || duration > 24*time.Hour {
|
||||
t.Errorf("expected duration between 0 and 24h, got %v", duration)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldownManager_ConcurrentAccess(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
const numGoroutines = 50
|
||||
const numOperations = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
tokenKey := "token" + string(rune('a'+id%10))
|
||||
for j := 0; j < numOperations; j++ {
|
||||
switch j % 6 {
|
||||
case 0:
|
||||
cm.SetCooldown(tokenKey, time.Duration(j)*time.Millisecond, CooldownReason429)
|
||||
case 1:
|
||||
cm.IsInCooldown(tokenKey)
|
||||
case 2:
|
||||
cm.GetRemainingCooldown(tokenKey)
|
||||
case 3:
|
||||
cm.GetCooldownReason(tokenKey)
|
||||
case 4:
|
||||
cm.ClearCooldown(tokenKey)
|
||||
case 5:
|
||||
cm.CleanupExpired()
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestCooldownReasonConstants(t *testing.T) {
|
||||
if CooldownReason429 != "rate_limit_exceeded" {
|
||||
t.Errorf("unexpected CooldownReason429: %s", CooldownReason429)
|
||||
}
|
||||
if CooldownReasonSuspended != "account_suspended" {
|
||||
t.Errorf("unexpected CooldownReasonSuspended: %s", CooldownReasonSuspended)
|
||||
}
|
||||
if CooldownReasonQuotaExhausted != "quota_exhausted" {
|
||||
t.Errorf("unexpected CooldownReasonQuotaExhausted: %s", CooldownReasonQuotaExhausted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConstants(t *testing.T) {
|
||||
if DefaultShortCooldown != 1*time.Minute {
|
||||
t.Errorf("unexpected DefaultShortCooldown: %v", DefaultShortCooldown)
|
||||
}
|
||||
if MaxShortCooldown != 5*time.Minute {
|
||||
t.Errorf("unexpected MaxShortCooldown: %v", MaxShortCooldown)
|
||||
}
|
||||
if LongCooldown != 24*time.Hour {
|
||||
t.Errorf("unexpected LongCooldown: %v", LongCooldown)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCooldown_OverwritesPrevious(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Hour, CooldownReason429)
|
||||
cm.SetCooldown("token1", 1*time.Minute, CooldownReasonSuspended)
|
||||
|
||||
reason := cm.GetCooldownReason("token1")
|
||||
if reason != CooldownReasonSuspended {
|
||||
t.Errorf("expected reason to be overwritten to %s, got %s", CooldownReasonSuspended, reason)
|
||||
}
|
||||
|
||||
remaining := cm.GetRemainingCooldown("token1")
|
||||
if remaining > 1*time.Minute {
|
||||
t.Errorf("expected remaining <= 1 minute, got %v", remaining)
|
||||
}
|
||||
}
|
||||
197
internal/auth/kiro/fingerprint.go
Normal file
197
internal/auth/kiro/fingerprint.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Fingerprint 多维度指纹信息
|
||||
type Fingerprint struct {
|
||||
SDKVersion string // 1.0.20-1.0.27
|
||||
OSType string // darwin/windows/linux
|
||||
OSVersion string // 10.0.22621
|
||||
NodeVersion string // 18.x/20.x/22.x
|
||||
KiroVersion string // 0.3.x-0.8.x
|
||||
KiroHash string // SHA256
|
||||
AcceptLanguage string
|
||||
ScreenResolution string // 1920x1080
|
||||
ColorDepth int // 24
|
||||
HardwareConcurrency int // CPU 核心数
|
||||
TimezoneOffset int
|
||||
}
|
||||
|
||||
// FingerprintManager 指纹管理器
|
||||
type FingerprintManager struct {
|
||||
mu sync.RWMutex
|
||||
fingerprints map[string]*Fingerprint // tokenKey -> fingerprint
|
||||
rng *rand.Rand
|
||||
}
|
||||
|
||||
var (
|
||||
sdkVersions = []string{
|
||||
"1.0.20", "1.0.21", "1.0.22", "1.0.23",
|
||||
"1.0.24", "1.0.25", "1.0.26", "1.0.27",
|
||||
}
|
||||
osTypes = []string{"darwin", "windows", "linux"}
|
||||
osVersions = map[string][]string{
|
||||
"darwin": {"14.0", "14.1", "14.2", "14.3", "14.4", "14.5", "15.0", "15.1"},
|
||||
"windows": {"10.0.19041", "10.0.19042", "10.0.19043", "10.0.19044", "10.0.22621", "10.0.22631"},
|
||||
"linux": {"5.15.0", "6.1.0", "6.2.0", "6.5.0", "6.6.0", "6.8.0"},
|
||||
}
|
||||
nodeVersions = []string{
|
||||
"18.17.0", "18.18.0", "18.19.0", "18.20.0",
|
||||
"20.9.0", "20.10.0", "20.11.0", "20.12.0", "20.13.0",
|
||||
"22.0.0", "22.1.0", "22.2.0", "22.3.0",
|
||||
}
|
||||
kiroVersions = []string{
|
||||
"0.3.0", "0.3.1", "0.4.0", "0.4.1", "0.5.0", "0.5.1",
|
||||
"0.6.0", "0.6.1", "0.7.0", "0.7.1", "0.8.0", "0.8.1",
|
||||
}
|
||||
acceptLanguages = []string{
|
||||
"en-US,en;q=0.9",
|
||||
"en-GB,en;q=0.9",
|
||||
"zh-CN,zh;q=0.9,en;q=0.8",
|
||||
"zh-TW,zh;q=0.9,en;q=0.8",
|
||||
"ja-JP,ja;q=0.9,en;q=0.8",
|
||||
"ko-KR,ko;q=0.9,en;q=0.8",
|
||||
"de-DE,de;q=0.9,en;q=0.8",
|
||||
"fr-FR,fr;q=0.9,en;q=0.8",
|
||||
}
|
||||
screenResolutions = []string{
|
||||
"1920x1080", "2560x1440", "3840x2160",
|
||||
"1366x768", "1440x900", "1680x1050",
|
||||
"2560x1600", "3440x1440",
|
||||
}
|
||||
colorDepths = []int{24, 32}
|
||||
hardwareConcurrencies = []int{4, 6, 8, 10, 12, 16, 20, 24, 32}
|
||||
timezoneOffsets = []int{-480, -420, -360, -300, -240, 0, 60, 120, 480, 540}
|
||||
)
|
||||
|
||||
// NewFingerprintManager 创建指纹管理器
|
||||
func NewFingerprintManager() *FingerprintManager {
|
||||
return &FingerprintManager{
|
||||
fingerprints: make(map[string]*Fingerprint),
|
||||
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
// GetFingerprint 获取或生成 Token 关联的指纹
|
||||
func (fm *FingerprintManager) GetFingerprint(tokenKey string) *Fingerprint {
|
||||
fm.mu.RLock()
|
||||
if fp, exists := fm.fingerprints[tokenKey]; exists {
|
||||
fm.mu.RUnlock()
|
||||
return fp
|
||||
}
|
||||
fm.mu.RUnlock()
|
||||
|
||||
fm.mu.Lock()
|
||||
defer fm.mu.Unlock()
|
||||
|
||||
if fp, exists := fm.fingerprints[tokenKey]; exists {
|
||||
return fp
|
||||
}
|
||||
|
||||
fp := fm.generateFingerprint(tokenKey)
|
||||
fm.fingerprints[tokenKey] = fp
|
||||
return fp
|
||||
}
|
||||
|
||||
// generateFingerprint 生成新的指纹
|
||||
func (fm *FingerprintManager) generateFingerprint(tokenKey string) *Fingerprint {
|
||||
osType := fm.randomChoice(osTypes)
|
||||
osVersion := fm.randomChoice(osVersions[osType])
|
||||
kiroVersion := fm.randomChoice(kiroVersions)
|
||||
|
||||
fp := &Fingerprint{
|
||||
SDKVersion: fm.randomChoice(sdkVersions),
|
||||
OSType: osType,
|
||||
OSVersion: osVersion,
|
||||
NodeVersion: fm.randomChoice(nodeVersions),
|
||||
KiroVersion: kiroVersion,
|
||||
AcceptLanguage: fm.randomChoice(acceptLanguages),
|
||||
ScreenResolution: fm.randomChoice(screenResolutions),
|
||||
ColorDepth: fm.randomIntChoice(colorDepths),
|
||||
HardwareConcurrency: fm.randomIntChoice(hardwareConcurrencies),
|
||||
TimezoneOffset: fm.randomIntChoice(timezoneOffsets),
|
||||
}
|
||||
|
||||
fp.KiroHash = fm.generateKiroHash(tokenKey, kiroVersion, osType)
|
||||
return fp
|
||||
}
|
||||
|
||||
// generateKiroHash 生成 Kiro Hash
|
||||
func (fm *FingerprintManager) generateKiroHash(tokenKey, kiroVersion, osType string) string {
|
||||
data := fmt.Sprintf("%s:%s:%s:%d", tokenKey, kiroVersion, osType, time.Now().UnixNano())
|
||||
hash := sha256.Sum256([]byte(data))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// randomChoice 随机选择字符串
|
||||
func (fm *FingerprintManager) randomChoice(choices []string) string {
|
||||
return choices[fm.rng.Intn(len(choices))]
|
||||
}
|
||||
|
||||
// randomIntChoice 随机选择整数
|
||||
func (fm *FingerprintManager) randomIntChoice(choices []int) int {
|
||||
return choices[fm.rng.Intn(len(choices))]
|
||||
}
|
||||
|
||||
// ApplyToRequest 将指纹信息应用到 HTTP 请求头
|
||||
func (fp *Fingerprint) ApplyToRequest(req *http.Request) {
|
||||
req.Header.Set("X-Kiro-SDK-Version", fp.SDKVersion)
|
||||
req.Header.Set("X-Kiro-OS-Type", fp.OSType)
|
||||
req.Header.Set("X-Kiro-OS-Version", fp.OSVersion)
|
||||
req.Header.Set("X-Kiro-Node-Version", fp.NodeVersion)
|
||||
req.Header.Set("X-Kiro-Version", fp.KiroVersion)
|
||||
req.Header.Set("X-Kiro-Hash", fp.KiroHash)
|
||||
req.Header.Set("Accept-Language", fp.AcceptLanguage)
|
||||
req.Header.Set("X-Screen-Resolution", fp.ScreenResolution)
|
||||
req.Header.Set("X-Color-Depth", fmt.Sprintf("%d", fp.ColorDepth))
|
||||
req.Header.Set("X-Hardware-Concurrency", fmt.Sprintf("%d", fp.HardwareConcurrency))
|
||||
req.Header.Set("X-Timezone-Offset", fmt.Sprintf("%d", fp.TimezoneOffset))
|
||||
}
|
||||
|
||||
// RemoveFingerprint 移除 Token 关联的指纹
|
||||
func (fm *FingerprintManager) RemoveFingerprint(tokenKey string) {
|
||||
fm.mu.Lock()
|
||||
defer fm.mu.Unlock()
|
||||
delete(fm.fingerprints, tokenKey)
|
||||
}
|
||||
|
||||
// Count 返回当前管理的指纹数量
|
||||
func (fm *FingerprintManager) Count() int {
|
||||
fm.mu.RLock()
|
||||
defer fm.mu.RUnlock()
|
||||
return len(fm.fingerprints)
|
||||
}
|
||||
|
||||
// BuildUserAgent 构建 User-Agent 字符串 (Kiro IDE 风格)
|
||||
// 格式: aws-sdk-js/{SDKVersion} ua/2.1 os/{OSType}#{OSVersion} lang/js md/nodejs#{NodeVersion} api/codewhispererstreaming#{SDKVersion} m/E KiroIDE-{KiroVersion}-{KiroHash}
|
||||
func (fp *Fingerprint) BuildUserAgent() string {
|
||||
return fmt.Sprintf(
|
||||
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
|
||||
fp.SDKVersion,
|
||||
fp.OSType,
|
||||
fp.OSVersion,
|
||||
fp.NodeVersion,
|
||||
fp.SDKVersion,
|
||||
fp.KiroVersion,
|
||||
fp.KiroHash,
|
||||
)
|
||||
}
|
||||
|
||||
// BuildAmzUserAgent 构建 X-Amz-User-Agent 字符串
|
||||
// 格式: aws-sdk-js/{SDKVersion} KiroIDE-{KiroVersion}-{KiroHash}
|
||||
func (fp *Fingerprint) BuildAmzUserAgent() string {
|
||||
return fmt.Sprintf(
|
||||
"aws-sdk-js/%s KiroIDE-%s-%s",
|
||||
fp.SDKVersion,
|
||||
fp.KiroVersion,
|
||||
fp.KiroHash,
|
||||
)
|
||||
}
|
||||
227
internal/auth/kiro/fingerprint_test.go
Normal file
227
internal/auth/kiro/fingerprint_test.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewFingerprintManager(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
if fm == nil {
|
||||
t.Fatal("expected non-nil FingerprintManager")
|
||||
}
|
||||
if fm.fingerprints == nil {
|
||||
t.Error("expected non-nil fingerprints map")
|
||||
}
|
||||
if fm.rng == nil {
|
||||
t.Error("expected non-nil rng")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFingerprint_NewToken(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
if fp == nil {
|
||||
t.Fatal("expected non-nil Fingerprint")
|
||||
}
|
||||
if fp.SDKVersion == "" {
|
||||
t.Error("expected non-empty SDKVersion")
|
||||
}
|
||||
if fp.OSType == "" {
|
||||
t.Error("expected non-empty OSType")
|
||||
}
|
||||
if fp.OSVersion == "" {
|
||||
t.Error("expected non-empty OSVersion")
|
||||
}
|
||||
if fp.NodeVersion == "" {
|
||||
t.Error("expected non-empty NodeVersion")
|
||||
}
|
||||
if fp.KiroVersion == "" {
|
||||
t.Error("expected non-empty KiroVersion")
|
||||
}
|
||||
if fp.KiroHash == "" {
|
||||
t.Error("expected non-empty KiroHash")
|
||||
}
|
||||
if fp.AcceptLanguage == "" {
|
||||
t.Error("expected non-empty AcceptLanguage")
|
||||
}
|
||||
if fp.ScreenResolution == "" {
|
||||
t.Error("expected non-empty ScreenResolution")
|
||||
}
|
||||
if fp.ColorDepth == 0 {
|
||||
t.Error("expected non-zero ColorDepth")
|
||||
}
|
||||
if fp.HardwareConcurrency == 0 {
|
||||
t.Error("expected non-zero HardwareConcurrency")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFingerprint_SameTokenReturnsSameFingerprint(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp1 := fm.GetFingerprint("token1")
|
||||
fp2 := fm.GetFingerprint("token1")
|
||||
|
||||
if fp1 != fp2 {
|
||||
t.Error("expected same fingerprint for same token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFingerprint_DifferentTokens(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp1 := fm.GetFingerprint("token1")
|
||||
fp2 := fm.GetFingerprint("token2")
|
||||
|
||||
if fp1 == fp2 {
|
||||
t.Error("expected different fingerprints for different tokens")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveFingerprint(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fm.GetFingerprint("token1")
|
||||
if fm.Count() != 1 {
|
||||
t.Fatalf("expected count 1, got %d", fm.Count())
|
||||
}
|
||||
|
||||
fm.RemoveFingerprint("token1")
|
||||
if fm.Count() != 0 {
|
||||
t.Errorf("expected count 0, got %d", fm.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveFingerprint_NonExistent(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fm.RemoveFingerprint("nonexistent")
|
||||
if fm.Count() != 0 {
|
||||
t.Errorf("expected count 0, got %d", fm.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCount(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
if fm.Count() != 0 {
|
||||
t.Errorf("expected count 0, got %d", fm.Count())
|
||||
}
|
||||
|
||||
fm.GetFingerprint("token1")
|
||||
fm.GetFingerprint("token2")
|
||||
fm.GetFingerprint("token3")
|
||||
|
||||
if fm.Count() != 3 {
|
||||
t.Errorf("expected count 3, got %d", fm.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyToRequest(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
fp.ApplyToRequest(req)
|
||||
|
||||
if req.Header.Get("X-Kiro-SDK-Version") != fp.SDKVersion {
|
||||
t.Error("X-Kiro-SDK-Version header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Kiro-OS-Type") != fp.OSType {
|
||||
t.Error("X-Kiro-OS-Type header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Kiro-OS-Version") != fp.OSVersion {
|
||||
t.Error("X-Kiro-OS-Version header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Kiro-Node-Version") != fp.NodeVersion {
|
||||
t.Error("X-Kiro-Node-Version header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Kiro-Version") != fp.KiroVersion {
|
||||
t.Error("X-Kiro-Version header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Kiro-Hash") != fp.KiroHash {
|
||||
t.Error("X-Kiro-Hash header mismatch")
|
||||
}
|
||||
if req.Header.Get("Accept-Language") != fp.AcceptLanguage {
|
||||
t.Error("Accept-Language header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Screen-Resolution") != fp.ScreenResolution {
|
||||
t.Error("X-Screen-Resolution header mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFingerprint_OSVersionMatchesOSType(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
fp := fm.GetFingerprint("token" + string(rune('a'+i)))
|
||||
validVersions := osVersions[fp.OSType]
|
||||
found := false
|
||||
for _, v := range validVersions {
|
||||
if v == fp.OSVersion {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("OS version %s not valid for OS type %s", fp.OSVersion, fp.OSType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
const numGoroutines = 100
|
||||
const numOperations = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numOperations; j++ {
|
||||
tokenKey := "token" + string(rune('a'+id%26))
|
||||
switch j % 4 {
|
||||
case 0:
|
||||
fm.GetFingerprint(tokenKey)
|
||||
case 1:
|
||||
fm.Count()
|
||||
case 2:
|
||||
fp := fm.GetFingerprint(tokenKey)
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
fp.ApplyToRequest(req)
|
||||
case 3:
|
||||
fm.RemoveFingerprint(tokenKey)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestKiroHashUniqueness(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
hashes := make(map[string]bool)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
fp := fm.GetFingerprint("token" + string(rune(i)))
|
||||
if hashes[fp.KiroHash] {
|
||||
t.Errorf("duplicate KiroHash detected: %s", fp.KiroHash)
|
||||
}
|
||||
hashes[fp.KiroHash] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestKiroHashFormat(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
if len(fp.KiroHash) != 64 {
|
||||
t.Errorf("expected KiroHash length 64 (SHA256 hex), got %d", len(fp.KiroHash))
|
||||
}
|
||||
|
||||
for _, c := range fp.KiroHash {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
|
||||
t.Errorf("invalid hex character in KiroHash: %c", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
174
internal/auth/kiro/jitter.go
Normal file
174
internal/auth/kiro/jitter.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Jitter configuration constants
|
||||
const (
|
||||
// JitterPercent is the default percentage of jitter to apply (±30%)
|
||||
JitterPercent = 0.30
|
||||
|
||||
// Human-like delay ranges
|
||||
ShortDelayMin = 50 * time.Millisecond // Minimum for rapid consecutive operations
|
||||
ShortDelayMax = 200 * time.Millisecond // Maximum for rapid consecutive operations
|
||||
NormalDelayMin = 1 * time.Second // Minimum for normal thinking time
|
||||
NormalDelayMax = 3 * time.Second // Maximum for normal thinking time
|
||||
LongDelayMin = 5 * time.Second // Minimum for reading/resting
|
||||
LongDelayMax = 10 * time.Second // Maximum for reading/resting
|
||||
|
||||
// Probability thresholds for human-like behavior
|
||||
ShortDelayProbability = 0.20 // 20% chance of short delay (consecutive ops)
|
||||
LongDelayProbability = 0.05 // 5% chance of long delay (reading/resting)
|
||||
NormalDelayProbability = 0.75 // 75% chance of normal delay (thinking)
|
||||
)
|
||||
|
||||
var (
|
||||
jitterRand *rand.Rand
|
||||
jitterRandOnce sync.Once
|
||||
jitterMu sync.Mutex
|
||||
lastRequestTime time.Time
|
||||
)
|
||||
|
||||
// initJitterRand initializes the random number generator for jitter calculations.
|
||||
// Uses a time-based seed for unpredictable but reproducible randomness.
|
||||
func initJitterRand() {
|
||||
jitterRandOnce.Do(func() {
|
||||
jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
})
|
||||
}
|
||||
|
||||
// RandomDelay generates a random delay between min and max duration.
|
||||
// Thread-safe implementation using mutex protection.
|
||||
func RandomDelay(min, max time.Duration) time.Duration {
|
||||
initJitterRand()
|
||||
jitterMu.Lock()
|
||||
defer jitterMu.Unlock()
|
||||
|
||||
if min >= max {
|
||||
return min
|
||||
}
|
||||
|
||||
rangeMs := max.Milliseconds() - min.Milliseconds()
|
||||
randomMs := jitterRand.Int63n(rangeMs)
|
||||
return min + time.Duration(randomMs)*time.Millisecond
|
||||
}
|
||||
|
||||
// JitterDelay adds jitter to a base delay.
|
||||
// Applies ±jitterPercent variation to the base delay.
|
||||
// For example, JitterDelay(1*time.Second, 0.30) returns a value between 700ms and 1300ms.
|
||||
func JitterDelay(baseDelay time.Duration, jitterPercent float64) time.Duration {
|
||||
initJitterRand()
|
||||
jitterMu.Lock()
|
||||
defer jitterMu.Unlock()
|
||||
|
||||
if jitterPercent <= 0 || jitterPercent > 1 {
|
||||
jitterPercent = JitterPercent
|
||||
}
|
||||
|
||||
// Calculate jitter range: base * jitterPercent
|
||||
jitterRange := float64(baseDelay) * jitterPercent
|
||||
|
||||
// Generate random value in range [-jitterRange, +jitterRange]
|
||||
jitter := (jitterRand.Float64()*2 - 1) * jitterRange
|
||||
|
||||
result := time.Duration(float64(baseDelay) + jitter)
|
||||
if result < 0 {
|
||||
return 0
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// JitterDelayDefault applies the default ±30% jitter to a base delay.
|
||||
func JitterDelayDefault(baseDelay time.Duration) time.Duration {
|
||||
return JitterDelay(baseDelay, JitterPercent)
|
||||
}
|
||||
|
||||
// HumanLikeDelay generates a delay that mimics human behavior patterns.
|
||||
// The delay is selected based on probability distribution:
|
||||
// - 20% chance: Short delay (50-200ms) - simulates consecutive rapid operations
|
||||
// - 75% chance: Normal delay (1-3s) - simulates thinking/reading time
|
||||
// - 5% chance: Long delay (5-10s) - simulates breaks/reading longer content
|
||||
//
|
||||
// Returns the delay duration (caller should call time.Sleep with this value).
|
||||
func HumanLikeDelay() time.Duration {
|
||||
initJitterRand()
|
||||
jitterMu.Lock()
|
||||
defer jitterMu.Unlock()
|
||||
|
||||
// Track time since last request for adaptive behavior
|
||||
now := time.Now()
|
||||
timeSinceLastRequest := now.Sub(lastRequestTime)
|
||||
lastRequestTime = now
|
||||
|
||||
// If requests are very close together, use short delay
|
||||
if timeSinceLastRequest < 500*time.Millisecond && timeSinceLastRequest > 0 {
|
||||
rangeMs := ShortDelayMax.Milliseconds() - ShortDelayMin.Milliseconds()
|
||||
randomMs := jitterRand.Int63n(rangeMs)
|
||||
return ShortDelayMin + time.Duration(randomMs)*time.Millisecond
|
||||
}
|
||||
|
||||
// Otherwise, use probability-based selection
|
||||
roll := jitterRand.Float64()
|
||||
|
||||
var min, max time.Duration
|
||||
switch {
|
||||
case roll < ShortDelayProbability:
|
||||
// Short delay - consecutive operations
|
||||
min, max = ShortDelayMin, ShortDelayMax
|
||||
case roll < ShortDelayProbability+LongDelayProbability:
|
||||
// Long delay - reading/resting
|
||||
min, max = LongDelayMin, LongDelayMax
|
||||
default:
|
||||
// Normal delay - thinking time
|
||||
min, max = NormalDelayMin, NormalDelayMax
|
||||
}
|
||||
|
||||
rangeMs := max.Milliseconds() - min.Milliseconds()
|
||||
randomMs := jitterRand.Int63n(rangeMs)
|
||||
return min + time.Duration(randomMs)*time.Millisecond
|
||||
}
|
||||
|
||||
// ApplyHumanLikeDelay applies human-like delay by sleeping.
|
||||
// This is a convenience function that combines HumanLikeDelay with time.Sleep.
|
||||
func ApplyHumanLikeDelay() {
|
||||
delay := HumanLikeDelay()
|
||||
if delay > 0 {
|
||||
time.Sleep(delay)
|
||||
}
|
||||
}
|
||||
|
||||
// ExponentialBackoffWithJitter calculates retry delay using exponential backoff with jitter.
|
||||
// Formula: min(baseDelay * 2^attempt + jitter, maxDelay)
|
||||
// This helps prevent thundering herd problem when multiple clients retry simultaneously.
|
||||
func ExponentialBackoffWithJitter(attempt int, baseDelay, maxDelay time.Duration) time.Duration {
|
||||
if attempt < 0 {
|
||||
attempt = 0
|
||||
}
|
||||
|
||||
// Calculate exponential backoff: baseDelay * 2^attempt
|
||||
backoff := baseDelay * time.Duration(1<<uint(attempt))
|
||||
if backoff > maxDelay {
|
||||
backoff = maxDelay
|
||||
}
|
||||
|
||||
// Add ±30% jitter
|
||||
return JitterDelay(backoff, JitterPercent)
|
||||
}
|
||||
|
||||
// ShouldSkipDelay determines if delay should be skipped based on context.
|
||||
// Returns true for streaming responses, WebSocket connections, etc.
|
||||
// This function can be extended to check additional skip conditions.
|
||||
func ShouldSkipDelay(isStreaming bool) bool {
|
||||
return isStreaming
|
||||
}
|
||||
|
||||
// ResetLastRequestTime resets the last request time tracker.
|
||||
// Useful for testing or when starting a new session.
|
||||
func ResetLastRequestTime() {
|
||||
jitterMu.Lock()
|
||||
defer jitterMu.Unlock()
|
||||
lastRequestTime = time.Time{}
|
||||
}
|
||||
187
internal/auth/kiro/metrics.go
Normal file
187
internal/auth/kiro/metrics.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TokenMetrics holds performance metrics for a single token.
|
||||
type TokenMetrics struct {
|
||||
SuccessRate float64 // Success rate (0.0 - 1.0)
|
||||
AvgLatency float64 // Average latency in milliseconds
|
||||
QuotaRemaining float64 // Remaining quota (0.0 - 1.0)
|
||||
LastUsed time.Time // Last usage timestamp
|
||||
FailCount int // Consecutive failure count
|
||||
TotalRequests int // Total request count
|
||||
successCount int // Internal: successful request count
|
||||
totalLatency float64 // Internal: cumulative latency
|
||||
}
|
||||
|
||||
// TokenScorer manages token metrics and scoring.
|
||||
type TokenScorer struct {
|
||||
mu sync.RWMutex
|
||||
metrics map[string]*TokenMetrics
|
||||
|
||||
// Scoring weights
|
||||
successRateWeight float64
|
||||
quotaWeight float64
|
||||
latencyWeight float64
|
||||
lastUsedWeight float64
|
||||
failPenaltyMultiplier float64
|
||||
}
|
||||
|
||||
// NewTokenScorer creates a new TokenScorer with default weights.
|
||||
func NewTokenScorer() *TokenScorer {
|
||||
return &TokenScorer{
|
||||
metrics: make(map[string]*TokenMetrics),
|
||||
successRateWeight: 0.4,
|
||||
quotaWeight: 0.25,
|
||||
latencyWeight: 0.2,
|
||||
lastUsedWeight: 0.15,
|
||||
failPenaltyMultiplier: 0.1,
|
||||
}
|
||||
}
|
||||
|
||||
// getOrCreateMetrics returns existing metrics or creates new ones.
|
||||
func (s *TokenScorer) getOrCreateMetrics(tokenKey string) *TokenMetrics {
|
||||
if m, ok := s.metrics[tokenKey]; ok {
|
||||
return m
|
||||
}
|
||||
m := &TokenMetrics{
|
||||
SuccessRate: 1.0,
|
||||
QuotaRemaining: 1.0,
|
||||
}
|
||||
s.metrics[tokenKey] = m
|
||||
return m
|
||||
}
|
||||
|
||||
// RecordRequest records the result of a request for a token.
|
||||
func (s *TokenScorer) RecordRequest(tokenKey string, success bool, latency time.Duration) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
m := s.getOrCreateMetrics(tokenKey)
|
||||
m.TotalRequests++
|
||||
m.LastUsed = time.Now()
|
||||
m.totalLatency += float64(latency.Milliseconds())
|
||||
|
||||
if success {
|
||||
m.successCount++
|
||||
m.FailCount = 0
|
||||
} else {
|
||||
m.FailCount++
|
||||
}
|
||||
|
||||
// Update derived metrics
|
||||
if m.TotalRequests > 0 {
|
||||
m.SuccessRate = float64(m.successCount) / float64(m.TotalRequests)
|
||||
m.AvgLatency = m.totalLatency / float64(m.TotalRequests)
|
||||
}
|
||||
}
|
||||
|
||||
// SetQuotaRemaining updates the remaining quota for a token.
|
||||
func (s *TokenScorer) SetQuotaRemaining(tokenKey string, quota float64) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
m := s.getOrCreateMetrics(tokenKey)
|
||||
m.QuotaRemaining = quota
|
||||
}
|
||||
|
||||
// GetMetrics returns a copy of the metrics for a token.
|
||||
func (s *TokenScorer) GetMetrics(tokenKey string) *TokenMetrics {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if m, ok := s.metrics[tokenKey]; ok {
|
||||
copy := *m
|
||||
return ©
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CalculateScore computes the score for a token (higher is better).
|
||||
func (s *TokenScorer) CalculateScore(tokenKey string) float64 {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
m, ok := s.metrics[tokenKey]
|
||||
if !ok {
|
||||
return 1.0 // New tokens get a high initial score
|
||||
}
|
||||
|
||||
// Success rate component (0-1)
|
||||
successScore := m.SuccessRate
|
||||
|
||||
// Quota component (0-1)
|
||||
quotaScore := m.QuotaRemaining
|
||||
|
||||
// Latency component (normalized, lower is better)
|
||||
// Using exponential decay: score = e^(-latency/1000)
|
||||
// 1000ms latency -> ~0.37 score, 100ms -> ~0.90 score
|
||||
latencyScore := math.Exp(-m.AvgLatency / 1000.0)
|
||||
if m.TotalRequests == 0 {
|
||||
latencyScore = 1.0
|
||||
}
|
||||
|
||||
// Last used component (prefer tokens not recently used)
|
||||
// Score increases as time since last use increases
|
||||
timeSinceUse := time.Since(m.LastUsed).Seconds()
|
||||
// Normalize: 60 seconds -> ~0.63 score, 0 seconds -> 0 score
|
||||
lastUsedScore := 1.0 - math.Exp(-timeSinceUse/60.0)
|
||||
if m.LastUsed.IsZero() {
|
||||
lastUsedScore = 1.0
|
||||
}
|
||||
|
||||
// Calculate weighted score
|
||||
score := s.successRateWeight*successScore +
|
||||
s.quotaWeight*quotaScore +
|
||||
s.latencyWeight*latencyScore +
|
||||
s.lastUsedWeight*lastUsedScore
|
||||
|
||||
// Apply consecutive failure penalty
|
||||
if m.FailCount > 0 {
|
||||
penalty := s.failPenaltyMultiplier * float64(m.FailCount)
|
||||
score = score * math.Max(0, 1.0-penalty)
|
||||
}
|
||||
|
||||
return score
|
||||
}
|
||||
|
||||
// SelectBestToken selects the token with the highest score.
|
||||
func (s *TokenScorer) SelectBestToken(tokens []string) string {
|
||||
if len(tokens) == 0 {
|
||||
return ""
|
||||
}
|
||||
if len(tokens) == 1 {
|
||||
return tokens[0]
|
||||
}
|
||||
|
||||
bestToken := tokens[0]
|
||||
bestScore := s.CalculateScore(tokens[0])
|
||||
|
||||
for _, token := range tokens[1:] {
|
||||
score := s.CalculateScore(token)
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
bestToken = token
|
||||
}
|
||||
}
|
||||
|
||||
return bestToken
|
||||
}
|
||||
|
||||
// ResetMetrics clears all metrics for a token.
|
||||
func (s *TokenScorer) ResetMetrics(tokenKey string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.metrics, tokenKey)
|
||||
}
|
||||
|
||||
// ResetAllMetrics clears all stored metrics.
|
||||
func (s *TokenScorer) ResetAllMetrics() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.metrics = make(map[string]*TokenMetrics)
|
||||
}
|
||||
301
internal/auth/kiro/metrics_test.go
Normal file
301
internal/auth/kiro/metrics_test.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewTokenScorer(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
if s == nil {
|
||||
t.Fatal("expected non-nil TokenScorer")
|
||||
}
|
||||
if s.metrics == nil {
|
||||
t.Error("expected non-nil metrics map")
|
||||
}
|
||||
if s.successRateWeight != 0.4 {
|
||||
t.Errorf("expected successRateWeight 0.4, got %f", s.successRateWeight)
|
||||
}
|
||||
if s.quotaWeight != 0.25 {
|
||||
t.Errorf("expected quotaWeight 0.25, got %f", s.quotaWeight)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordRequest_Success(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m == nil {
|
||||
t.Fatal("expected non-nil metrics")
|
||||
}
|
||||
if m.TotalRequests != 1 {
|
||||
t.Errorf("expected TotalRequests 1, got %d", m.TotalRequests)
|
||||
}
|
||||
if m.SuccessRate != 1.0 {
|
||||
t.Errorf("expected SuccessRate 1.0, got %f", m.SuccessRate)
|
||||
}
|
||||
if m.FailCount != 0 {
|
||||
t.Errorf("expected FailCount 0, got %d", m.FailCount)
|
||||
}
|
||||
if m.AvgLatency != 100 {
|
||||
t.Errorf("expected AvgLatency 100, got %f", m.AvgLatency)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordRequest_Failure(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", false, 200*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.SuccessRate != 0.0 {
|
||||
t.Errorf("expected SuccessRate 0.0, got %f", m.SuccessRate)
|
||||
}
|
||||
if m.FailCount != 1 {
|
||||
t.Errorf("expected FailCount 1, got %d", m.FailCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordRequest_MixedResults(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.TotalRequests != 4 {
|
||||
t.Errorf("expected TotalRequests 4, got %d", m.TotalRequests)
|
||||
}
|
||||
if m.SuccessRate != 0.75 {
|
||||
t.Errorf("expected SuccessRate 0.75, got %f", m.SuccessRate)
|
||||
}
|
||||
if m.FailCount != 0 {
|
||||
t.Errorf("expected FailCount 0 (reset on success), got %d", m.FailCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordRequest_ConsecutiveFailures(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.FailCount != 3 {
|
||||
t.Errorf("expected FailCount 3, got %d", m.FailCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetQuotaRemaining(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.SetQuotaRemaining("token1", 0.5)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.QuotaRemaining != 0.5 {
|
||||
t.Errorf("expected QuotaRemaining 0.5, got %f", m.QuotaRemaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMetrics_NonExistent(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
m := s.GetMetrics("nonexistent")
|
||||
if m != nil {
|
||||
t.Error("expected nil metrics for non-existent token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMetrics_ReturnsCopy(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m1 := s.GetMetrics("token1")
|
||||
m1.TotalRequests = 999
|
||||
|
||||
m2 := s.GetMetrics("token1")
|
||||
if m2.TotalRequests == 999 {
|
||||
t.Error("GetMetrics should return a copy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateScore_NewToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
score := s.CalculateScore("newtoken")
|
||||
if score != 1.0 {
|
||||
t.Errorf("expected score 1.0 for new token, got %f", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateScore_PerfectToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 50*time.Millisecond)
|
||||
s.SetQuotaRemaining("token1", 1.0)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
score := s.CalculateScore("token1")
|
||||
if score < 0.5 || score > 1.0 {
|
||||
t.Errorf("expected high score for perfect token, got %f", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateScore_FailedToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
for i := 0; i < 5; i++ {
|
||||
s.RecordRequest("token1", false, 1000*time.Millisecond)
|
||||
}
|
||||
s.SetQuotaRemaining("token1", 0.1)
|
||||
|
||||
score := s.CalculateScore("token1")
|
||||
if score > 0.5 {
|
||||
t.Errorf("expected low score for failed token, got %f", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateScore_FailPenalty(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
scoreNoFail := s.CalculateScore("token1")
|
||||
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
scoreWithFail := s.CalculateScore("token1")
|
||||
|
||||
if scoreWithFail >= scoreNoFail {
|
||||
t.Errorf("expected lower score with consecutive failures: noFail=%f, withFail=%f", scoreNoFail, scoreWithFail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectBestToken_Empty(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
best := s.SelectBestToken([]string{})
|
||||
if best != "" {
|
||||
t.Errorf("expected empty string for empty tokens, got %s", best)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectBestToken_SingleToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
best := s.SelectBestToken([]string{"token1"})
|
||||
if best != "token1" {
|
||||
t.Errorf("expected token1, got %s", best)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectBestToken_MultipleTokens(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
|
||||
s.RecordRequest("bad", false, 1000*time.Millisecond)
|
||||
s.RecordRequest("bad", false, 1000*time.Millisecond)
|
||||
s.SetQuotaRemaining("bad", 0.1)
|
||||
|
||||
s.RecordRequest("good", true, 50*time.Millisecond)
|
||||
s.SetQuotaRemaining("good", 0.9)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
best := s.SelectBestToken([]string{"bad", "good"})
|
||||
if best != "good" {
|
||||
t.Errorf("expected good token to be selected, got %s", best)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetMetrics(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.ResetMetrics("token1")
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m != nil {
|
||||
t.Error("expected nil metrics after reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetAllMetrics(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token2", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token3", true, 100*time.Millisecond)
|
||||
|
||||
s.ResetAllMetrics()
|
||||
|
||||
if s.GetMetrics("token1") != nil {
|
||||
t.Error("expected nil metrics for token1 after reset all")
|
||||
}
|
||||
if s.GetMetrics("token2") != nil {
|
||||
t.Error("expected nil metrics for token2 after reset all")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenScorer_ConcurrentAccess(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
const numGoroutines = 50
|
||||
const numOperations = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
tokenKey := "token" + string(rune('a'+id%10))
|
||||
for j := 0; j < numOperations; j++ {
|
||||
switch j % 6 {
|
||||
case 0:
|
||||
s.RecordRequest(tokenKey, j%2 == 0, time.Duration(j)*time.Millisecond)
|
||||
case 1:
|
||||
s.SetQuotaRemaining(tokenKey, float64(j%100)/100)
|
||||
case 2:
|
||||
s.GetMetrics(tokenKey)
|
||||
case 3:
|
||||
s.CalculateScore(tokenKey)
|
||||
case 4:
|
||||
s.SelectBestToken([]string{tokenKey, "token_x", "token_y"})
|
||||
case 5:
|
||||
if j%20 == 0 {
|
||||
s.ResetMetrics(tokenKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAvgLatencyCalculation(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", true, 200*time.Millisecond)
|
||||
s.RecordRequest("token1", true, 300*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.AvgLatency != 200 {
|
||||
t.Errorf("expected AvgLatency 200, got %f", m.AvgLatency)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLastUsedUpdated(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
before := time.Now()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.LastUsed.Before(before) {
|
||||
t.Error("expected LastUsed to be after test start time")
|
||||
}
|
||||
if m.LastUsed.After(time.Now()) {
|
||||
t.Error("expected LastUsed to be before or equal to now")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultQuotaForNewToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.QuotaRemaining != 1.0 {
|
||||
t.Errorf("expected default QuotaRemaining 1.0, got %f", m.QuotaRemaining)
|
||||
}
|
||||
}
|
||||
@@ -163,6 +163,13 @@ func (o *KiroOAuth) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, err
|
||||
return ssoClient.LoginWithBuilderID(ctx)
|
||||
}
|
||||
|
||||
// LoginWithBuilderIDAuthCode performs OAuth login with AWS Builder ID using authorization code flow.
|
||||
// This provides a better UX than device code flow as it uses automatic browser callback.
|
||||
func (o *KiroOAuth) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTokenData, error) {
|
||||
ssoClient := NewSSOOIDCClient(o.cfg)
|
||||
return ssoClient.LoginWithBuilderIDAuthCode(ctx)
|
||||
}
|
||||
|
||||
// exchangeCodeForToken exchanges the authorization code for tokens.
|
||||
func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier, redirectURI string) (*KiroTokenData, error) {
|
||||
payload := map[string]string{
|
||||
@@ -220,6 +227,7 @@ func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: "", // Caller should preserve original provider
|
||||
Region: "us-east-1",
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -278,6 +286,7 @@ func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*Kir
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: "", // Caller should preserve original provider
|
||||
Region: "us-east-1",
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
982
internal/auth/kiro/oauth_web.go
Normal file
982
internal/auth/kiro/oauth_web.go
Normal file
@@ -0,0 +1,982 @@
|
||||
// Package kiro provides OAuth Web authentication for Kiro.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultSessionExpiry = 10 * time.Minute
|
||||
pollIntervalSeconds = 5
|
||||
)
|
||||
|
||||
type authSessionStatus string
|
||||
|
||||
const (
|
||||
statusPending authSessionStatus = "pending"
|
||||
statusSuccess authSessionStatus = "success"
|
||||
statusFailed authSessionStatus = "failed"
|
||||
)
|
||||
|
||||
type webAuthSession struct {
|
||||
stateID string
|
||||
deviceCode string
|
||||
userCode string
|
||||
authURL string
|
||||
verificationURI string
|
||||
expiresIn int
|
||||
interval int
|
||||
status authSessionStatus
|
||||
startedAt time.Time
|
||||
completedAt time.Time
|
||||
expiresAt time.Time
|
||||
error string
|
||||
tokenData *KiroTokenData
|
||||
ssoClient *SSOOIDCClient
|
||||
clientID string
|
||||
clientSecret string
|
||||
region string
|
||||
cancelFunc context.CancelFunc
|
||||
authMethod string // "google", "github", "builder-id", "idc"
|
||||
startURL string // Used for IDC
|
||||
codeVerifier string // Used for social auth PKCE
|
||||
codeChallenge string // Used for social auth PKCE
|
||||
}
|
||||
|
||||
type OAuthWebHandler struct {
|
||||
cfg *config.Config
|
||||
sessions map[string]*webAuthSession
|
||||
mu sync.RWMutex
|
||||
onTokenObtained func(*KiroTokenData)
|
||||
}
|
||||
|
||||
func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler {
|
||||
return &OAuthWebHandler{
|
||||
cfg: cfg,
|
||||
sessions: make(map[string]*webAuthSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) SetTokenCallback(callback func(*KiroTokenData)) {
|
||||
h.onTokenObtained = callback
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) RegisterRoutes(router gin.IRouter) {
|
||||
oauth := router.Group("/v0/oauth/kiro")
|
||||
{
|
||||
oauth.GET("", h.handleSelect)
|
||||
oauth.GET("/start", h.handleStart)
|
||||
oauth.GET("/callback", h.handleCallback)
|
||||
oauth.GET("/social/callback", h.handleSocialCallback)
|
||||
oauth.GET("/status", h.handleStatus)
|
||||
oauth.POST("/import", h.handleImportToken)
|
||||
oauth.POST("/refresh", h.handleManualRefresh)
|
||||
}
|
||||
}
|
||||
|
||||
func generateStateID() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleSelect(c *gin.Context) {
|
||||
h.renderSelectPage(c)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleStart(c *gin.Context) {
|
||||
method := c.Query("method")
|
||||
|
||||
if method == "" {
|
||||
c.Redirect(http.StatusFound, "/v0/oauth/kiro")
|
||||
return
|
||||
}
|
||||
|
||||
switch method {
|
||||
case "google", "github":
|
||||
// Google/GitHub social login is not supported for third-party apps
|
||||
// due to AWS Cognito redirect_uri restrictions
|
||||
h.renderError(c, "Google/GitHub login is not available for third-party applications. Please use AWS Builder ID or import your token from Kiro IDE.")
|
||||
case "builder-id":
|
||||
h.startBuilderIDAuth(c)
|
||||
case "idc":
|
||||
h.startIDCAuth(c)
|
||||
default:
|
||||
h.renderError(c, fmt.Sprintf("Unknown authentication method: %s", method))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) startSocialAuth(c *gin.Context, method string) {
|
||||
stateID, err := generateStateID()
|
||||
if err != nil {
|
||||
h.renderError(c, "Failed to generate state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
codeVerifier, codeChallenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
h.renderError(c, "Failed to generate PKCE parameters")
|
||||
return
|
||||
}
|
||||
|
||||
socialClient := NewSocialAuthClient(h.cfg)
|
||||
|
||||
var provider string
|
||||
if method == "google" {
|
||||
provider = string(ProviderGoogle)
|
||||
} else {
|
||||
provider = string(ProviderGitHub)
|
||||
}
|
||||
|
||||
redirectURI := h.getSocialCallbackURL(c)
|
||||
authURL := socialClient.buildLoginURL(provider, redirectURI, codeChallenge, stateID)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
|
||||
session := &webAuthSession{
|
||||
stateID: stateID,
|
||||
authMethod: method,
|
||||
authURL: authURL,
|
||||
status: statusPending,
|
||||
startedAt: time.Now(),
|
||||
expiresIn: 600,
|
||||
codeVerifier: codeVerifier,
|
||||
codeChallenge: codeChallenge,
|
||||
region: "us-east-1",
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
h.sessions[stateID] = session
|
||||
h.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
h.mu.Lock()
|
||||
if session.status == statusPending {
|
||||
session.status = statusFailed
|
||||
session.error = "Authentication timed out"
|
||||
}
|
||||
h.mu.Unlock()
|
||||
}()
|
||||
|
||||
c.Redirect(http.StatusFound, authURL)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) getSocialCallbackURL(c *gin.Context) string {
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https" {
|
||||
scheme = "https"
|
||||
}
|
||||
return fmt.Sprintf("%s://%s/v0/oauth/kiro/social/callback", scheme, c.Request.Host)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) startBuilderIDAuth(c *gin.Context) {
|
||||
stateID, err := generateStateID()
|
||||
if err != nil {
|
||||
h.renderError(c, "Failed to generate state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
region := defaultIDCRegion
|
||||
startURL := builderIDStartURL
|
||||
|
||||
ssoClient := NewSSOOIDCClient(h.cfg)
|
||||
|
||||
regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to register client: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
|
||||
c.Request.Context(),
|
||||
regResp.ClientID,
|
||||
regResp.ClientSecret,
|
||||
startURL,
|
||||
region,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to start device authorization: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)
|
||||
|
||||
session := &webAuthSession{
|
||||
stateID: stateID,
|
||||
deviceCode: authResp.DeviceCode,
|
||||
userCode: authResp.UserCode,
|
||||
authURL: authResp.VerificationURIComplete,
|
||||
verificationURI: authResp.VerificationURI,
|
||||
expiresIn: authResp.ExpiresIn,
|
||||
interval: authResp.Interval,
|
||||
status: statusPending,
|
||||
startedAt: time.Now(),
|
||||
ssoClient: ssoClient,
|
||||
clientID: regResp.ClientID,
|
||||
clientSecret: regResp.ClientSecret,
|
||||
region: region,
|
||||
authMethod: "builder-id",
|
||||
startURL: startURL,
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
h.sessions[stateID] = session
|
||||
h.mu.Unlock()
|
||||
|
||||
go h.pollForToken(ctx, session)
|
||||
|
||||
h.renderStartPage(c, session)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) startIDCAuth(c *gin.Context) {
|
||||
startURL := c.Query("startUrl")
|
||||
region := c.Query("region")
|
||||
|
||||
if startURL == "" {
|
||||
h.renderError(c, "Missing startUrl parameter for IDC authentication")
|
||||
return
|
||||
}
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
|
||||
stateID, err := generateStateID()
|
||||
if err != nil {
|
||||
h.renderError(c, "Failed to generate state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
ssoClient := NewSSOOIDCClient(h.cfg)
|
||||
|
||||
regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to register client: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
|
||||
c.Request.Context(),
|
||||
regResp.ClientID,
|
||||
regResp.ClientSecret,
|
||||
startURL,
|
||||
region,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to start device authorization: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)
|
||||
|
||||
session := &webAuthSession{
|
||||
stateID: stateID,
|
||||
deviceCode: authResp.DeviceCode,
|
||||
userCode: authResp.UserCode,
|
||||
authURL: authResp.VerificationURIComplete,
|
||||
verificationURI: authResp.VerificationURI,
|
||||
expiresIn: authResp.ExpiresIn,
|
||||
interval: authResp.Interval,
|
||||
status: statusPending,
|
||||
startedAt: time.Now(),
|
||||
ssoClient: ssoClient,
|
||||
clientID: regResp.ClientID,
|
||||
clientSecret: regResp.ClientSecret,
|
||||
region: region,
|
||||
authMethod: "idc",
|
||||
startURL: startURL,
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
h.sessions[stateID] = session
|
||||
h.mu.Unlock()
|
||||
|
||||
go h.pollForToken(ctx, session)
|
||||
|
||||
h.renderStartPage(c, session)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSession) {
|
||||
defer session.cancelFunc()
|
||||
|
||||
interval := time.Duration(session.interval) * time.Second
|
||||
if interval < time.Duration(pollIntervalSeconds)*time.Second {
|
||||
interval = time.Duration(pollIntervalSeconds) * time.Second
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
h.mu.Lock()
|
||||
if session.status == statusPending {
|
||||
session.status = statusFailed
|
||||
session.error = "Authentication timed out"
|
||||
}
|
||||
h.mu.Unlock()
|
||||
return
|
||||
case <-ticker.C:
|
||||
tokenResp, err := h.ssoClient(session).CreateTokenWithRegion(
|
||||
ctx,
|
||||
session.clientID,
|
||||
session.clientSecret,
|
||||
session.deviceCode,
|
||||
session.region,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
if errStr == ErrAuthorizationPending.Error() {
|
||||
continue
|
||||
}
|
||||
if errStr == ErrSlowDown.Error() {
|
||||
interval += 5 * time.Second
|
||||
ticker.Reset(interval)
|
||||
continue
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
session.status = statusFailed
|
||||
session.error = errStr
|
||||
session.completedAt = time.Now()
|
||||
h.mu.Unlock()
|
||||
|
||||
log.Errorf("OAuth Web: token polling failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||
profileArn := session.ssoClient.fetchProfileArn(ctx, tokenResp.AccessToken)
|
||||
email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken)
|
||||
|
||||
tokenData := &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: profileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: session.authMethod,
|
||||
Provider: "AWS",
|
||||
ClientID: session.clientID,
|
||||
ClientSecret: session.clientSecret,
|
||||
Email: email,
|
||||
Region: session.region,
|
||||
StartURL: session.startURL,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
session.status = statusSuccess
|
||||
session.completedAt = time.Now()
|
||||
session.expiresAt = expiresAt
|
||||
session.tokenData = tokenData
|
||||
h.mu.Unlock()
|
||||
|
||||
if h.onTokenObtained != nil {
|
||||
h.onTokenObtained(tokenData)
|
||||
}
|
||||
|
||||
// Save token to file
|
||||
h.saveTokenToFile(tokenData)
|
||||
|
||||
log.Infof("OAuth Web: authentication successful for %s", email)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// saveTokenToFile saves the token data to the auth directory
|
||||
func (h *OAuthWebHandler) saveTokenToFile(tokenData *KiroTokenData) {
|
||||
// Get auth directory from config or use default
|
||||
authDir := ""
|
||||
if h.cfg != nil && h.cfg.AuthDir != "" {
|
||||
var err error
|
||||
authDir, err = util.ResolveAuthDir(h.cfg.AuthDir)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to resolve auth directory: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to default location
|
||||
if authDir == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to get home directory: %v", err)
|
||||
return
|
||||
}
|
||||
authDir = filepath.Join(home, ".cli-proxy-api")
|
||||
}
|
||||
|
||||
// Create directory if not exists
|
||||
if err := os.MkdirAll(authDir, 0700); err != nil {
|
||||
log.Errorf("OAuth Web: failed to create auth directory: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate filename based on auth method
|
||||
// Format: kiro-{authMethod}.json or kiro-{authMethod}-{email}.json
|
||||
fileName := fmt.Sprintf("kiro-%s.json", tokenData.AuthMethod)
|
||||
if tokenData.Email != "" {
|
||||
// Sanitize email for filename (replace @ and . with -)
|
||||
sanitizedEmail := tokenData.Email
|
||||
sanitizedEmail = strings.ReplaceAll(sanitizedEmail, "@", "-")
|
||||
sanitizedEmail = strings.ReplaceAll(sanitizedEmail, ".", "-")
|
||||
fileName = fmt.Sprintf("kiro-%s-%s.json", tokenData.AuthMethod, sanitizedEmail)
|
||||
}
|
||||
|
||||
authFilePath := filepath.Join(authDir, fileName)
|
||||
|
||||
// Convert to storage format and save
|
||||
storage := &KiroTokenStorage{
|
||||
Type: "kiro",
|
||||
AccessToken: tokenData.AccessToken,
|
||||
RefreshToken: tokenData.RefreshToken,
|
||||
ProfileArn: tokenData.ProfileArn,
|
||||
ExpiresAt: tokenData.ExpiresAt,
|
||||
AuthMethod: tokenData.AuthMethod,
|
||||
Provider: tokenData.Provider,
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
ClientID: tokenData.ClientID,
|
||||
ClientSecret: tokenData.ClientSecret,
|
||||
Region: tokenData.Region,
|
||||
StartURL: tokenData.StartURL,
|
||||
Email: tokenData.Email,
|
||||
}
|
||||
|
||||
if err := storage.SaveTokenToFile(authFilePath); err != nil {
|
||||
log.Errorf("OAuth Web: failed to save token to file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("OAuth Web: token saved to %s", authFilePath)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) ssoClient(session *webAuthSession) *SSOOIDCClient {
|
||||
return session.ssoClient
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleCallback(c *gin.Context) {
|
||||
stateID := c.Query("state")
|
||||
errParam := c.Query("error")
|
||||
|
||||
if errParam != "" {
|
||||
h.renderError(c, errParam)
|
||||
return
|
||||
}
|
||||
|
||||
if stateID == "" {
|
||||
h.renderError(c, "Missing state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
session, exists := h.sessions[stateID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
h.renderError(c, "Invalid or expired session")
|
||||
return
|
||||
}
|
||||
|
||||
if session.status == statusSuccess {
|
||||
h.renderSuccess(c, session)
|
||||
} else if session.status == statusFailed {
|
||||
h.renderError(c, session.error)
|
||||
} else {
|
||||
c.Redirect(http.StatusFound, "/v0/oauth/kiro/start")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleSocialCallback(c *gin.Context) {
|
||||
stateID := c.Query("state")
|
||||
code := c.Query("code")
|
||||
errParam := c.Query("error")
|
||||
|
||||
if errParam != "" {
|
||||
h.renderError(c, errParam)
|
||||
return
|
||||
}
|
||||
|
||||
if stateID == "" {
|
||||
h.renderError(c, "Missing state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
if code == "" {
|
||||
h.renderError(c, "Missing authorization code")
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
session, exists := h.sessions[stateID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
h.renderError(c, "Invalid or expired session")
|
||||
return
|
||||
}
|
||||
|
||||
if session.authMethod != "google" && session.authMethod != "github" {
|
||||
h.renderError(c, "Invalid session type for social callback")
|
||||
return
|
||||
}
|
||||
|
||||
socialClient := NewSocialAuthClient(h.cfg)
|
||||
redirectURI := h.getSocialCallbackURL(c)
|
||||
|
||||
tokenReq := &CreateTokenRequest{
|
||||
Code: code,
|
||||
CodeVerifier: session.codeVerifier,
|
||||
RedirectURI: redirectURI,
|
||||
}
|
||||
|
||||
tokenResp, err := socialClient.CreateToken(c.Request.Context(), tokenReq)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: social token exchange failed: %v", err)
|
||||
h.mu.Lock()
|
||||
session.status = statusFailed
|
||||
session.error = fmt.Sprintf("Token exchange failed: %v", err)
|
||||
session.completedAt = time.Now()
|
||||
h.mu.Unlock()
|
||||
h.renderError(c, session.error)
|
||||
return
|
||||
}
|
||||
|
||||
expiresIn := tokenResp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
email := ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||
|
||||
var provider string
|
||||
if session.authMethod == "google" {
|
||||
provider = string(ProviderGoogle)
|
||||
} else {
|
||||
provider = string(ProviderGitHub)
|
||||
}
|
||||
|
||||
tokenData := &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: tokenResp.ProfileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: session.authMethod,
|
||||
Provider: provider,
|
||||
Email: email,
|
||||
Region: "us-east-1",
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
session.status = statusSuccess
|
||||
session.completedAt = time.Now()
|
||||
session.expiresAt = expiresAt
|
||||
session.tokenData = tokenData
|
||||
h.mu.Unlock()
|
||||
|
||||
if session.cancelFunc != nil {
|
||||
session.cancelFunc()
|
||||
}
|
||||
|
||||
if h.onTokenObtained != nil {
|
||||
h.onTokenObtained(tokenData)
|
||||
}
|
||||
|
||||
// Save token to file
|
||||
h.saveTokenToFile(tokenData)
|
||||
|
||||
log.Infof("OAuth Web: social authentication successful for %s via %s", email, provider)
|
||||
h.renderSuccess(c, session)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleStatus(c *gin.Context) {
|
||||
stateID := c.Query("state")
|
||||
if stateID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing state parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
session, exists := h.sessions[stateID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
|
||||
return
|
||||
}
|
||||
|
||||
response := gin.H{
|
||||
"status": string(session.status),
|
||||
}
|
||||
|
||||
switch session.status {
|
||||
case statusPending:
|
||||
elapsed := time.Since(session.startedAt).Seconds()
|
||||
remaining := float64(session.expiresIn) - elapsed
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
response["remaining_seconds"] = int(remaining)
|
||||
case statusSuccess:
|
||||
response["completed_at"] = session.completedAt.Format(time.RFC3339)
|
||||
response["expires_at"] = session.expiresAt.Format(time.RFC3339)
|
||||
case statusFailed:
|
||||
response["error"] = session.error
|
||||
response["failed_at"] = session.completedAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderStartPage(c *gin.Context, session *webAuthSession) {
|
||||
tmpl, err := template.New("start").Parse(oauthWebStartPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"AuthURL": session.authURL,
|
||||
"UserCode": session.userCode,
|
||||
"ExpiresIn": session.expiresIn,
|
||||
"StateID": session.stateID,
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
if err := tmpl.Execute(c.Writer, data); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderSelectPage(c *gin.Context) {
|
||||
tmpl, err := template.New("select").Parse(oauthWebSelectPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse select template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
if err := tmpl.Execute(c.Writer, nil); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render select template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderError(c *gin.Context, errMsg string) {
|
||||
tmpl, err := template.New("error").Parse(oauthWebErrorPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse error template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"Error": errMsg,
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.Status(http.StatusBadRequest)
|
||||
if err := tmpl.Execute(c.Writer, data); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render error template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderSuccess(c *gin.Context, session *webAuthSession) {
|
||||
tmpl, err := template.New("success").Parse(oauthWebSuccessPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse success template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"ExpiresAt": session.expiresAt.Format(time.RFC3339),
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
if err := tmpl.Execute(c.Writer, data); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render success template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) CleanupExpiredSessions() {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for id, session := range h.sessions {
|
||||
if session.status != statusPending && now.Sub(session.completedAt) > 30*time.Minute {
|
||||
delete(h.sessions, id)
|
||||
} else if session.status == statusPending && now.Sub(session.startedAt) > defaultSessionExpiry {
|
||||
session.cancelFunc()
|
||||
delete(h.sessions, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) GetSession(stateID string) (*webAuthSession, bool) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
session, exists := h.sessions[stateID]
|
||||
return session, exists
|
||||
}
|
||||
|
||||
// ImportTokenRequest represents the request body for token import
|
||||
type ImportTokenRequest struct {
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
// handleImportToken handles manual refresh token import from Kiro IDE
|
||||
func (h *OAuthWebHandler) handleImportToken(c *gin.Context) {
|
||||
var req ImportTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": "Invalid request body",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
refreshToken := strings.TrimSpace(req.RefreshToken)
|
||||
if refreshToken == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": "Refresh token is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate token format
|
||||
if !strings.HasPrefix(refreshToken, "aorAAAAAG") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": "Invalid token format. Token should start with aorAAAAAG...",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Create social auth client to refresh and validate the token
|
||||
socialClient := NewSocialAuthClient(h.cfg)
|
||||
|
||||
// Refresh the token to validate it and get access token
|
||||
tokenData, err := socialClient.RefreshSocialToken(c.Request.Context(), refreshToken)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: token refresh failed during import: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": fmt.Sprintf("Token validation failed: %v", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Set the original refresh token (the refreshed one might be empty)
|
||||
if tokenData.RefreshToken == "" {
|
||||
tokenData.RefreshToken = refreshToken
|
||||
}
|
||||
tokenData.AuthMethod = "social"
|
||||
tokenData.Provider = "imported"
|
||||
|
||||
// Notify callback if set
|
||||
if h.onTokenObtained != nil {
|
||||
h.onTokenObtained(tokenData)
|
||||
}
|
||||
|
||||
// Save token to file
|
||||
h.saveTokenToFile(tokenData)
|
||||
|
||||
// Generate filename for response
|
||||
fileName := fmt.Sprintf("kiro-%s.json", tokenData.AuthMethod)
|
||||
if tokenData.Email != "" {
|
||||
sanitizedEmail := strings.ReplaceAll(tokenData.Email, "@", "-")
|
||||
sanitizedEmail = strings.ReplaceAll(sanitizedEmail, ".", "-")
|
||||
fileName = fmt.Sprintf("kiro-%s-%s.json", tokenData.AuthMethod, sanitizedEmail)
|
||||
}
|
||||
|
||||
log.Infof("OAuth Web: token imported successfully")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Token imported successfully",
|
||||
"fileName": fileName,
|
||||
})
|
||||
}
|
||||
|
||||
// handleManualRefresh handles manual token refresh requests from the web UI.
|
||||
// This allows users to trigger a token refresh when needed, without waiting
|
||||
// for the automatic 30-second check and 20-minute-before-expiry refresh cycle.
|
||||
// Uses the same refresh logic as kiro_executor.Refresh for consistency.
|
||||
func (h *OAuthWebHandler) handleManualRefresh(c *gin.Context) {
|
||||
authDir := ""
|
||||
if h.cfg != nil && h.cfg.AuthDir != "" {
|
||||
var err error
|
||||
authDir, err = util.ResolveAuthDir(h.cfg.AuthDir)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to resolve auth directory: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if authDir == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"error": "Failed to get home directory",
|
||||
})
|
||||
return
|
||||
}
|
||||
authDir = filepath.Join(home, ".cli-proxy-api")
|
||||
}
|
||||
|
||||
// Find all kiro token files in the auth directory
|
||||
files, err := os.ReadDir(authDir)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"error": fmt.Sprintf("Failed to read auth directory: %v", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var refreshedCount int
|
||||
var errors []string
|
||||
|
||||
for _, file := range files {
|
||||
if file.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := file.Name()
|
||||
if !strings.HasPrefix(name, "kiro-") || !strings.HasSuffix(name, ".json") {
|
||||
continue
|
||||
}
|
||||
|
||||
filePath := filepath.Join(authDir, name)
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
errors = append(errors, fmt.Sprintf("%s: read error - %v", name, err))
|
||||
continue
|
||||
}
|
||||
|
||||
var storage KiroTokenStorage
|
||||
if err := json.Unmarshal(data, &storage); err != nil {
|
||||
errors = append(errors, fmt.Sprintf("%s: parse error - %v", name, err))
|
||||
continue
|
||||
}
|
||||
|
||||
if storage.RefreshToken == "" {
|
||||
errors = append(errors, fmt.Sprintf("%s: no refresh token", name))
|
||||
continue
|
||||
}
|
||||
|
||||
// Refresh token using the same logic as kiro_executor.Refresh
|
||||
tokenData, err := h.refreshTokenData(c.Request.Context(), &storage)
|
||||
if err != nil {
|
||||
errors = append(errors, fmt.Sprintf("%s: refresh failed - %v", name, err))
|
||||
continue
|
||||
}
|
||||
|
||||
// Update storage with new token data
|
||||
storage.AccessToken = tokenData.AccessToken
|
||||
if tokenData.RefreshToken != "" {
|
||||
storage.RefreshToken = tokenData.RefreshToken
|
||||
}
|
||||
storage.ExpiresAt = tokenData.ExpiresAt
|
||||
storage.LastRefresh = time.Now().Format(time.RFC3339)
|
||||
if tokenData.ProfileArn != "" {
|
||||
storage.ProfileArn = tokenData.ProfileArn
|
||||
}
|
||||
|
||||
// Write updated token back to file
|
||||
updatedData, err := json.MarshalIndent(storage, "", " ")
|
||||
if err != nil {
|
||||
errors = append(errors, fmt.Sprintf("%s: marshal error - %v", name, err))
|
||||
continue
|
||||
}
|
||||
|
||||
tmpFile := filePath + ".tmp"
|
||||
if err := os.WriteFile(tmpFile, updatedData, 0600); err != nil {
|
||||
errors = append(errors, fmt.Sprintf("%s: write error - %v", name, err))
|
||||
continue
|
||||
}
|
||||
if err := os.Rename(tmpFile, filePath); err != nil {
|
||||
errors = append(errors, fmt.Sprintf("%s: rename error - %v", name, err))
|
||||
continue
|
||||
}
|
||||
|
||||
log.Infof("OAuth Web: manually refreshed token in %s, expires at %s", name, tokenData.ExpiresAt)
|
||||
refreshedCount++
|
||||
|
||||
// Notify callback if set
|
||||
if h.onTokenObtained != nil {
|
||||
h.onTokenObtained(tokenData)
|
||||
}
|
||||
}
|
||||
|
||||
if refreshedCount == 0 && len(errors) > 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": fmt.Sprintf("All refresh attempts failed: %v", errors),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response := gin.H{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf("Refreshed %d token(s)", refreshedCount),
|
||||
"refreshedCount": refreshedCount,
|
||||
}
|
||||
if len(errors) > 0 {
|
||||
response["warnings"] = errors
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// refreshTokenData refreshes a token using the appropriate method based on auth type.
|
||||
// This mirrors the logic in kiro_executor.Refresh for consistency.
|
||||
func (h *OAuthWebHandler) refreshTokenData(ctx context.Context, storage *KiroTokenStorage) (*KiroTokenData, error) {
|
||||
ssoClient := NewSSOOIDCClient(h.cfg)
|
||||
|
||||
switch {
|
||||
case storage.ClientID != "" && storage.ClientSecret != "" && storage.AuthMethod == "idc" && storage.Region != "":
|
||||
// IDC refresh with region-specific endpoint
|
||||
log.Debugf("OAuth Web: using SSO OIDC refresh for IDC (region=%s)", storage.Region)
|
||||
return ssoClient.RefreshTokenWithRegion(ctx, storage.ClientID, storage.ClientSecret, storage.RefreshToken, storage.Region, storage.StartURL)
|
||||
|
||||
case storage.ClientID != "" && storage.ClientSecret != "" && storage.AuthMethod == "builder-id":
|
||||
// Builder ID refresh with default endpoint
|
||||
log.Debugf("OAuth Web: using SSO OIDC refresh for AWS Builder ID")
|
||||
return ssoClient.RefreshToken(ctx, storage.ClientID, storage.ClientSecret, storage.RefreshToken)
|
||||
|
||||
default:
|
||||
// Fallback to Kiro's OAuth refresh endpoint (for social auth: Google/GitHub)
|
||||
log.Debugf("OAuth Web: using Kiro OAuth refresh endpoint")
|
||||
oauth := NewKiroOAuth(h.cfg)
|
||||
return oauth.RefreshToken(ctx, storage.RefreshToken)
|
||||
}
|
||||
}
|
||||
779
internal/auth/kiro/oauth_web_templates.go
Normal file
779
internal/auth/kiro/oauth_web_templates.go
Normal file
@@ -0,0 +1,779 @@
|
||||
// Package kiro provides OAuth Web authentication templates.
|
||||
package kiro
|
||||
|
||||
const (
|
||||
oauthWebStartPageHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>AWS SSO Authentication</title>
|
||||
<style>
|
||||
* { box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
.container {
|
||||
max-width: 500px;
|
||||
width: 100%;
|
||||
background: #fff;
|
||||
padding: 40px;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
|
||||
}
|
||||
h1 {
|
||||
margin: 0 0 10px;
|
||||
color: #333;
|
||||
font-size: 24px;
|
||||
text-align: center;
|
||||
}
|
||||
.subtitle {
|
||||
text-align: center;
|
||||
color: #666;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
.step {
|
||||
background: #f8f9fa;
|
||||
padding: 20px;
|
||||
border-radius: 8px;
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
.step-title {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
.step-number {
|
||||
width: 28px;
|
||||
height: 28px;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
border-radius: 50%;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 14px;
|
||||
margin-right: 12px;
|
||||
}
|
||||
.user-code {
|
||||
background: #e7f3ff;
|
||||
border: 2px dashed #2196F3;
|
||||
border-radius: 8px;
|
||||
padding: 20px;
|
||||
text-align: center;
|
||||
margin-top: 10px;
|
||||
}
|
||||
.user-code-label {
|
||||
font-size: 12px;
|
||||
color: #666;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 1px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.user-code-value {
|
||||
font-size: 32px;
|
||||
font-weight: bold;
|
||||
font-family: monospace;
|
||||
color: #2196F3;
|
||||
letter-spacing: 4px;
|
||||
}
|
||||
.auth-btn {
|
||||
display: block;
|
||||
width: 100%;
|
||||
padding: 15px;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
text-align: center;
|
||||
text-decoration: none;
|
||||
border-radius: 8px;
|
||||
font-weight: 600;
|
||||
font-size: 16px;
|
||||
transition: all 0.3s;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
margin-top: 20px;
|
||||
}
|
||||
.auth-btn:hover {
|
||||
background: #5568d3;
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
|
||||
}
|
||||
.status {
|
||||
margin-top: 30px;
|
||||
padding: 20px;
|
||||
background: #f8f9fa;
|
||||
border-radius: 8px;
|
||||
text-align: center;
|
||||
}
|
||||
.status-pending { border-left: 4px solid #ffc107; }
|
||||
.status-success { border-left: 4px solid #28a745; }
|
||||
.status-failed { border-left: 4px solid #dc3545; }
|
||||
.spinner {
|
||||
border: 3px solid #f3f3f3;
|
||||
border-top: 3px solid #667eea;
|
||||
border-radius: 50%;
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
animation: spin 1s linear infinite;
|
||||
margin: 0 auto 15px;
|
||||
}
|
||||
@keyframes spin {
|
||||
0% { transform: rotate(0deg); }
|
||||
100% { transform: rotate(360deg); }
|
||||
}
|
||||
.timer {
|
||||
font-size: 24px;
|
||||
font-weight: bold;
|
||||
color: #667eea;
|
||||
margin: 10px 0;
|
||||
}
|
||||
.timer.warning { color: #ffc107; }
|
||||
.timer.danger { color: #dc3545; }
|
||||
.status-message { color: #666; line-height: 1.6; }
|
||||
.success-icon, .error-icon { font-size: 48px; margin-bottom: 15px; }
|
||||
.info-box {
|
||||
background: #e7f3ff;
|
||||
border-left: 4px solid #2196F3;
|
||||
padding: 15px;
|
||||
margin-top: 20px;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
color: #666;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>🔐 AWS SSO Authentication</h1>
|
||||
<p class="subtitle">Follow the steps below to complete authentication</p>
|
||||
|
||||
<div class="step">
|
||||
<div class="step-title">
|
||||
<span class="step-number">1</span>
|
||||
Click the button below to open the authorization page
|
||||
</div>
|
||||
<a href="{{.AuthURL}}" target="_blank" class="auth-btn" id="authBtn">
|
||||
🚀 Open Authorization Page
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div class="step">
|
||||
<div class="step-title">
|
||||
<span class="step-number">2</span>
|
||||
Enter the verification code below
|
||||
</div>
|
||||
<div class="user-code">
|
||||
<div class="user-code-label">Verification Code</div>
|
||||
<div class="user-code-value">{{.UserCode}}</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="step">
|
||||
<div class="step-title">
|
||||
<span class="step-number">3</span>
|
||||
Complete AWS SSO login
|
||||
</div>
|
||||
<p style="color: #666; font-size: 14px; margin-top: 10px;">
|
||||
Use your AWS SSO account to login and authorize
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div class="status status-pending" id="statusBox">
|
||||
<div class="spinner" id="spinner"></div>
|
||||
<div class="timer" id="timer">{{.ExpiresIn}}s</div>
|
||||
<div class="status-message" id="statusMessage">
|
||||
Waiting for authorization...
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="info-box">
|
||||
💡 <strong>Tip:</strong> The authorization page will open in a new tab. This page will automatically update once authorization is complete.
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let pollInterval;
|
||||
let timerInterval;
|
||||
let remainingSeconds = {{.ExpiresIn}};
|
||||
const stateID = "{{.StateID}}";
|
||||
|
||||
setTimeout(() => {
|
||||
document.getElementById('authBtn').click();
|
||||
}, 500);
|
||||
|
||||
function pollStatus() {
|
||||
fetch('/v0/oauth/kiro/status?state=' + stateID)
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
console.log('Status:', data);
|
||||
if (data.status === 'success') {
|
||||
clearInterval(pollInterval);
|
||||
clearInterval(timerInterval);
|
||||
showSuccess(data);
|
||||
} else if (data.status === 'failed') {
|
||||
clearInterval(pollInterval);
|
||||
clearInterval(timerInterval);
|
||||
showError(data);
|
||||
} else {
|
||||
remainingSeconds = data.remaining_seconds || 0;
|
||||
}
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Poll error:', error);
|
||||
});
|
||||
}
|
||||
|
||||
function updateTimer() {
|
||||
const timerEl = document.getElementById('timer');
|
||||
const minutes = Math.floor(remainingSeconds / 60);
|
||||
const seconds = remainingSeconds % 60;
|
||||
timerEl.textContent = minutes + ':' + seconds.toString().padStart(2, '0');
|
||||
|
||||
if (remainingSeconds < 60) {
|
||||
timerEl.className = 'timer danger';
|
||||
} else if (remainingSeconds < 180) {
|
||||
timerEl.className = 'timer warning';
|
||||
} else {
|
||||
timerEl.className = 'timer';
|
||||
}
|
||||
|
||||
remainingSeconds--;
|
||||
|
||||
if (remainingSeconds < 0) {
|
||||
clearInterval(timerInterval);
|
||||
clearInterval(pollInterval);
|
||||
showError({ error: 'Authentication timed out. Please refresh and try again.' });
|
||||
}
|
||||
}
|
||||
|
||||
function showSuccess(data) {
|
||||
const statusBox = document.getElementById('statusBox');
|
||||
statusBox.className = 'status status-success';
|
||||
statusBox.innerHTML = '<div class="success-icon">✅</div>' +
|
||||
'<div class="status-message">' +
|
||||
'<strong>Authentication Successful!</strong><br>' +
|
||||
'Token expires: ' + new Date(data.expires_at).toLocaleString() +
|
||||
'</div>';
|
||||
}
|
||||
|
||||
function showError(data) {
|
||||
const statusBox = document.getElementById('statusBox');
|
||||
statusBox.className = 'status status-failed';
|
||||
statusBox.innerHTML = '<div class="error-icon">❌</div>' +
|
||||
'<div class="status-message">' +
|
||||
'<strong>Authentication Failed</strong><br>' +
|
||||
(data.error || 'Unknown error') +
|
||||
'</div>' +
|
||||
'<button class="auth-btn" onclick="location.reload()" style="margin-top: 15px;">' +
|
||||
'🔄 Retry' +
|
||||
'</button>';
|
||||
}
|
||||
|
||||
pollInterval = setInterval(pollStatus, 3000);
|
||||
timerInterval = setInterval(updateTimer, 1000);
|
||||
pollStatus();
|
||||
</script>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
oauthWebErrorPageHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Authentication Failed</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
max-width: 600px;
|
||||
margin: 50px auto;
|
||||
padding: 20px;
|
||||
background: #f5f5f5;
|
||||
}
|
||||
.error {
|
||||
background: #fff;
|
||||
padding: 30px;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||
border-left: 4px solid #dc3545;
|
||||
}
|
||||
h1 { color: #dc3545; margin-top: 0; }
|
||||
.error-message { color: #666; line-height: 1.6; }
|
||||
.retry-btn {
|
||||
display: inline-block;
|
||||
margin-top: 20px;
|
||||
padding: 10px 20px;
|
||||
background: #007bff;
|
||||
color: white;
|
||||
text-decoration: none;
|
||||
border-radius: 4px;
|
||||
}
|
||||
.retry-btn:hover { background: #0056b3; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="error">
|
||||
<h1>❌ Authentication Failed</h1>
|
||||
<div class="error-message">
|
||||
<p><strong>Error:</strong></p>
|
||||
<p>{{.Error}}</p>
|
||||
</div>
|
||||
<a href="/v0/oauth/kiro/start" class="retry-btn">🔄 Retry</a>
|
||||
</div>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
oauthWebSuccessPageHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Authentication Successful</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
max-width: 600px;
|
||||
margin: 50px auto;
|
||||
padding: 20px;
|
||||
background: #f5f5f5;
|
||||
}
|
||||
.success {
|
||||
background: #fff;
|
||||
padding: 30px;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||
border-left: 4px solid #28a745;
|
||||
text-align: center;
|
||||
}
|
||||
h1 { color: #28a745; margin-top: 0; }
|
||||
.success-message { color: #666; line-height: 1.6; }
|
||||
.icon { font-size: 48px; margin-bottom: 15px; }
|
||||
.expires { font-size: 14px; color: #999; margin-top: 15px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="success">
|
||||
<div class="icon">✅</div>
|
||||
<h1>Authentication Successful!</h1>
|
||||
<div class="success-message">
|
||||
<p>You can close this window.</p>
|
||||
</div>
|
||||
<div class="expires">Token expires: {{.ExpiresAt}}</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
oauthWebSelectPageHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Select Authentication Method</title>
|
||||
<style>
|
||||
* { box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
.container {
|
||||
max-width: 500px;
|
||||
width: 100%;
|
||||
background: #fff;
|
||||
padding: 40px;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
|
||||
}
|
||||
h1 {
|
||||
margin: 0 0 10px;
|
||||
color: #333;
|
||||
font-size: 24px;
|
||||
text-align: center;
|
||||
}
|
||||
.subtitle {
|
||||
text-align: center;
|
||||
color: #666;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
.auth-methods {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 15px;
|
||||
}
|
||||
.auth-btn {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
width: 100%;
|
||||
padding: 15px 20px;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
text-decoration: none;
|
||||
border-radius: 8px;
|
||||
font-weight: 600;
|
||||
font-size: 16px;
|
||||
transition: all 0.3s;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
}
|
||||
.auth-btn:hover {
|
||||
background: #5568d3;
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
|
||||
}
|
||||
.auth-btn .icon {
|
||||
font-size: 24px;
|
||||
margin-right: 15px;
|
||||
width: 32px;
|
||||
text-align: center;
|
||||
}
|
||||
.auth-btn.google { background: #4285F4; }
|
||||
.auth-btn.google:hover { background: #3367D6; }
|
||||
.auth-btn.github { background: #24292e; }
|
||||
.auth-btn.github:hover { background: #1a1e22; }
|
||||
.auth-btn.aws { background: #FF9900; }
|
||||
.auth-btn.aws:hover { background: #E68A00; }
|
||||
.auth-btn.idc { background: #232F3E; }
|
||||
.auth-btn.idc:hover { background: #1a242f; }
|
||||
.idc-form {
|
||||
background: #f8f9fa;
|
||||
padding: 20px;
|
||||
border-radius: 8px;
|
||||
margin-top: 15px;
|
||||
display: none;
|
||||
}
|
||||
.idc-form.show {
|
||||
display: block;
|
||||
}
|
||||
.form-group {
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
.form-group label {
|
||||
display: block;
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
margin-bottom: 8px;
|
||||
font-size: 14px;
|
||||
}
|
||||
.form-group input {
|
||||
width: 100%;
|
||||
padding: 12px;
|
||||
border: 2px solid #e0e0e0;
|
||||
border-radius: 6px;
|
||||
font-size: 14px;
|
||||
transition: border-color 0.3s;
|
||||
}
|
||||
.form-group input:focus {
|
||||
outline: none;
|
||||
border-color: #667eea;
|
||||
}
|
||||
.form-group .hint {
|
||||
font-size: 12px;
|
||||
color: #999;
|
||||
margin-top: 5px;
|
||||
}
|
||||
.submit-btn {
|
||||
display: block;
|
||||
width: 100%;
|
||||
padding: 15px;
|
||||
background: #232F3E;
|
||||
color: white;
|
||||
text-align: center;
|
||||
text-decoration: none;
|
||||
border-radius: 8px;
|
||||
font-weight: 600;
|
||||
font-size: 16px;
|
||||
transition: all 0.3s;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
}
|
||||
.submit-btn:hover {
|
||||
background: #1a242f;
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 12px rgba(35, 47, 62, 0.4);
|
||||
}
|
||||
.divider {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin: 20px 0;
|
||||
}
|
||||
.divider::before,
|
||||
.divider::after {
|
||||
content: "";
|
||||
flex: 1;
|
||||
border-bottom: 1px solid #e0e0e0;
|
||||
}
|
||||
.divider span {
|
||||
padding: 0 15px;
|
||||
color: #999;
|
||||
font-size: 14px;
|
||||
}
|
||||
.info-box {
|
||||
background: #e7f3ff;
|
||||
border-left: 4px solid #2196F3;
|
||||
padding: 15px;
|
||||
margin-top: 20px;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
color: #666;
|
||||
}
|
||||
.warning-box {
|
||||
background: #fff3cd;
|
||||
border-left: 4px solid #ffc107;
|
||||
padding: 15px;
|
||||
margin-top: 20px;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
color: #856404;
|
||||
}
|
||||
.auth-btn.manual { background: #6c757d; }
|
||||
.auth-btn.manual:hover { background: #5a6268; }
|
||||
.auth-btn.refresh { background: #17a2b8; }
|
||||
.auth-btn.refresh:hover { background: #138496; }
|
||||
.auth-btn.refresh:disabled { background: #7fb3bd; cursor: not-allowed; }
|
||||
.manual-form {
|
||||
background: #f8f9fa;
|
||||
padding: 20px;
|
||||
border-radius: 8px;
|
||||
margin-top: 15px;
|
||||
display: none;
|
||||
}
|
||||
.manual-form.show {
|
||||
display: block;
|
||||
}
|
||||
.form-group textarea {
|
||||
width: 100%;
|
||||
padding: 12px;
|
||||
border: 2px solid #e0e0e0;
|
||||
border-radius: 6px;
|
||||
font-size: 14px;
|
||||
font-family: monospace;
|
||||
transition: border-color 0.3s;
|
||||
resize: vertical;
|
||||
min-height: 80px;
|
||||
}
|
||||
.form-group textarea:focus {
|
||||
outline: none;
|
||||
border-color: #667eea;
|
||||
}
|
||||
.status-message {
|
||||
padding: 15px;
|
||||
border-radius: 6px;
|
||||
margin-top: 15px;
|
||||
display: none;
|
||||
}
|
||||
.status-message.success {
|
||||
background: #d4edda;
|
||||
color: #155724;
|
||||
display: block;
|
||||
}
|
||||
.status-message.error {
|
||||
background: #f8d7da;
|
||||
color: #721c24;
|
||||
display: block;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>🔐 Select Authentication Method</h1>
|
||||
<p class="subtitle">Choose how you want to authenticate with Kiro</p>
|
||||
|
||||
<div class="auth-methods">
|
||||
<a href="/v0/oauth/kiro/start?method=builder-id" class="auth-btn aws">
|
||||
<span class="icon">🔶</span>
|
||||
AWS Builder ID (Recommended)
|
||||
</a>
|
||||
|
||||
<button type="button" class="auth-btn idc" onclick="toggleIdcForm()">
|
||||
<span class="icon">🏢</span>
|
||||
AWS Identity Center (IDC)
|
||||
</button>
|
||||
|
||||
<div class="divider"><span>or</span></div>
|
||||
|
||||
<button type="button" class="auth-btn manual" onclick="toggleManualForm()">
|
||||
<span class="icon">📋</span>
|
||||
Import RefreshToken from Kiro IDE
|
||||
</button>
|
||||
|
||||
<button type="button" class="auth-btn refresh" onclick="manualRefresh()" id="refreshBtn">
|
||||
<span class="icon">🔄</span>
|
||||
Manual Refresh All Tokens
|
||||
</button>
|
||||
|
||||
<div class="status-message" id="refreshStatus"></div>
|
||||
</div>
|
||||
|
||||
<div class="idc-form" id="idcForm">
|
||||
<form action="/v0/oauth/kiro/start" method="get">
|
||||
<input type="hidden" name="method" value="idc">
|
||||
|
||||
<div class="form-group">
|
||||
<label for="startUrl">Start URL</label>
|
||||
<input type="url" id="startUrl" name="startUrl" placeholder="https://your-org.awsapps.com/start" required>
|
||||
<div class="hint">Your AWS Identity Center Start URL</div>
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label for="region">Region</label>
|
||||
<input type="text" id="region" name="region" value="us-east-1" placeholder="us-east-1">
|
||||
<div class="hint">AWS Region for your Identity Center</div>
|
||||
</div>
|
||||
|
||||
<button type="submit" class="submit-btn">
|
||||
🚀 Continue with IDC
|
||||
</button>
|
||||
</form>
|
||||
</div>
|
||||
|
||||
<div class="manual-form" id="manualForm">
|
||||
<form id="importForm" onsubmit="submitImport(event)">
|
||||
<div class="form-group">
|
||||
<label for="refreshToken">Refresh Token</label>
|
||||
<textarea id="refreshToken" name="refreshToken" placeholder="Paste your refreshToken here (starts with aorAAAAAG...)" required></textarea>
|
||||
<div class="hint">Copy from Kiro IDE: ~/.kiro/kiro-auth-token.json → refreshToken field</div>
|
||||
</div>
|
||||
|
||||
<button type="submit" class="submit-btn" id="importBtn">
|
||||
📥 Import Token
|
||||
</button>
|
||||
|
||||
<div class="status-message" id="importStatus"></div>
|
||||
</form>
|
||||
</div>
|
||||
|
||||
<div class="warning-box">
|
||||
⚠️ <strong>Note:</strong> Google and GitHub login are not available for third-party applications due to AWS Cognito restrictions. Please use AWS Builder ID or import your token from Kiro IDE.
|
||||
</div>
|
||||
|
||||
<div class="info-box">
|
||||
💡 <strong>How to get RefreshToken:</strong><br>
|
||||
1. Open Kiro IDE and login with Google/GitHub<br>
|
||||
2. Find the token file: <code>~/.kiro/kiro-auth-token.json</code><br>
|
||||
3. Copy the <code>refreshToken</code> value and paste it above
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
function toggleIdcForm() {
|
||||
const idcForm = document.getElementById('idcForm');
|
||||
const manualForm = document.getElementById('manualForm');
|
||||
manualForm.classList.remove('show');
|
||||
idcForm.classList.toggle('show');
|
||||
if (idcForm.classList.contains('show')) {
|
||||
document.getElementById('startUrl').focus();
|
||||
}
|
||||
}
|
||||
|
||||
function toggleManualForm() {
|
||||
const idcForm = document.getElementById('idcForm');
|
||||
const manualForm = document.getElementById('manualForm');
|
||||
idcForm.classList.remove('show');
|
||||
manualForm.classList.toggle('show');
|
||||
if (manualForm.classList.contains('show')) {
|
||||
document.getElementById('refreshToken').focus();
|
||||
}
|
||||
}
|
||||
|
||||
async function submitImport(event) {
|
||||
event.preventDefault();
|
||||
const refreshToken = document.getElementById('refreshToken').value.trim();
|
||||
const statusEl = document.getElementById('importStatus');
|
||||
const btn = document.getElementById('importBtn');
|
||||
|
||||
if (!refreshToken) {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = 'Please enter a refresh token';
|
||||
return;
|
||||
}
|
||||
|
||||
if (!refreshToken.startsWith('aorAAAAAG')) {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = 'Invalid token format. Token should start with aorAAAAAG...';
|
||||
return;
|
||||
}
|
||||
|
||||
btn.disabled = true;
|
||||
btn.textContent = '⏳ Importing...';
|
||||
statusEl.className = 'status-message';
|
||||
statusEl.style.display = 'none';
|
||||
|
||||
try {
|
||||
const response = await fetch('/v0/oauth/kiro/import', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ refreshToken: refreshToken })
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (response.ok && data.success) {
|
||||
statusEl.className = 'status-message success';
|
||||
statusEl.textContent = '✅ Token imported successfully! File: ' + (data.fileName || 'kiro-token.json');
|
||||
} else {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = '❌ ' + (data.error || data.message || 'Import failed');
|
||||
}
|
||||
} catch (error) {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = '❌ Network error: ' + error.message;
|
||||
} finally {
|
||||
btn.disabled = false;
|
||||
btn.textContent = '📥 Import Token';
|
||||
}
|
||||
}
|
||||
|
||||
async function manualRefresh() {
|
||||
const btn = document.getElementById('refreshBtn');
|
||||
const statusEl = document.getElementById('refreshStatus');
|
||||
|
||||
btn.disabled = true;
|
||||
btn.innerHTML = '<span class="icon">⏳</span> Refreshing...';
|
||||
statusEl.className = 'status-message';
|
||||
statusEl.style.display = 'none';
|
||||
|
||||
try {
|
||||
const response = await fetch('/v0/oauth/kiro/refresh', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' }
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (response.ok && data.success) {
|
||||
statusEl.className = 'status-message success';
|
||||
let msg = '✅ ' + data.message;
|
||||
if (data.warnings && data.warnings.length > 0) {
|
||||
msg += ' (Warnings: ' + data.warnings.join('; ') + ')';
|
||||
}
|
||||
statusEl.textContent = msg;
|
||||
} else {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = '❌ ' + (data.error || data.message || 'Refresh failed');
|
||||
}
|
||||
} catch (error) {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = '❌ Network error: ' + error.message;
|
||||
} finally {
|
||||
btn.disabled = false;
|
||||
btn.innerHTML = '<span class="icon">🔄</span> Manual Refresh All Tokens';
|
||||
}
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>`
|
||||
)
|
||||
@@ -471,7 +471,7 @@ foreach ($port in $ports) {
|
||||
|
||||
// Create batch wrapper
|
||||
batchPath := filepath.Join(scriptDir, "kiro-oauth-handler.bat")
|
||||
batchContent := fmt.Sprintf("@echo off\npowershell -ExecutionPolicy Bypass -File \"%s\" \"%%1\"\n", scriptPath)
|
||||
batchContent := fmt.Sprintf("@echo off\npowershell -ExecutionPolicy Bypass -File \"%s\" %%1\n", scriptPath)
|
||||
|
||||
if err := os.WriteFile(batchPath, []byte(batchContent), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write batch wrapper: %w", err)
|
||||
|
||||
316
internal/auth/kiro/rate_limiter.go
Normal file
316
internal/auth/kiro/rate_limiter.go
Normal file
@@ -0,0 +1,316 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMinTokenInterval = 10 * time.Second
|
||||
DefaultMaxTokenInterval = 30 * time.Second
|
||||
DefaultDailyMaxRequests = 500
|
||||
DefaultJitterPercent = 0.3
|
||||
DefaultBackoffBase = 2 * time.Minute
|
||||
DefaultBackoffMax = 60 * time.Minute
|
||||
DefaultBackoffMultiplier = 2.0
|
||||
DefaultSuspendCooldown = 24 * time.Hour
|
||||
)
|
||||
|
||||
// TokenState Token 状态
|
||||
type TokenState struct {
|
||||
LastRequest time.Time
|
||||
RequestCount int
|
||||
CooldownEnd time.Time
|
||||
FailCount int
|
||||
DailyRequests int
|
||||
DailyResetTime time.Time
|
||||
IsSuspended bool
|
||||
SuspendedAt time.Time
|
||||
SuspendReason string
|
||||
}
|
||||
|
||||
// RateLimiter 频率限制器
|
||||
type RateLimiter struct {
|
||||
mu sync.RWMutex
|
||||
states map[string]*TokenState
|
||||
minTokenInterval time.Duration
|
||||
maxTokenInterval time.Duration
|
||||
dailyMaxRequests int
|
||||
jitterPercent float64
|
||||
backoffBase time.Duration
|
||||
backoffMax time.Duration
|
||||
backoffMultiplier float64
|
||||
suspendCooldown time.Duration
|
||||
rng *rand.Rand
|
||||
}
|
||||
|
||||
// NewRateLimiter 创建默认配置的频率限制器
|
||||
func NewRateLimiter() *RateLimiter {
|
||||
return &RateLimiter{
|
||||
states: make(map[string]*TokenState),
|
||||
minTokenInterval: DefaultMinTokenInterval,
|
||||
maxTokenInterval: DefaultMaxTokenInterval,
|
||||
dailyMaxRequests: DefaultDailyMaxRequests,
|
||||
jitterPercent: DefaultJitterPercent,
|
||||
backoffBase: DefaultBackoffBase,
|
||||
backoffMax: DefaultBackoffMax,
|
||||
backoffMultiplier: DefaultBackoffMultiplier,
|
||||
suspendCooldown: DefaultSuspendCooldown,
|
||||
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimiterConfig 频率限制器配置
|
||||
type RateLimiterConfig struct {
|
||||
MinTokenInterval time.Duration
|
||||
MaxTokenInterval time.Duration
|
||||
DailyMaxRequests int
|
||||
JitterPercent float64
|
||||
BackoffBase time.Duration
|
||||
BackoffMax time.Duration
|
||||
BackoffMultiplier float64
|
||||
SuspendCooldown time.Duration
|
||||
}
|
||||
|
||||
// NewRateLimiterWithConfig 使用自定义配置创建频率限制器
|
||||
func NewRateLimiterWithConfig(cfg RateLimiterConfig) *RateLimiter {
|
||||
rl := NewRateLimiter()
|
||||
if cfg.MinTokenInterval > 0 {
|
||||
rl.minTokenInterval = cfg.MinTokenInterval
|
||||
}
|
||||
if cfg.MaxTokenInterval > 0 {
|
||||
rl.maxTokenInterval = cfg.MaxTokenInterval
|
||||
}
|
||||
if cfg.DailyMaxRequests > 0 {
|
||||
rl.dailyMaxRequests = cfg.DailyMaxRequests
|
||||
}
|
||||
if cfg.JitterPercent > 0 {
|
||||
rl.jitterPercent = cfg.JitterPercent
|
||||
}
|
||||
if cfg.BackoffBase > 0 {
|
||||
rl.backoffBase = cfg.BackoffBase
|
||||
}
|
||||
if cfg.BackoffMax > 0 {
|
||||
rl.backoffMax = cfg.BackoffMax
|
||||
}
|
||||
if cfg.BackoffMultiplier > 0 {
|
||||
rl.backoffMultiplier = cfg.BackoffMultiplier
|
||||
}
|
||||
if cfg.SuspendCooldown > 0 {
|
||||
rl.suspendCooldown = cfg.SuspendCooldown
|
||||
}
|
||||
return rl
|
||||
}
|
||||
|
||||
// getOrCreateState 获取或创建 Token 状态
|
||||
func (rl *RateLimiter) getOrCreateState(tokenKey string) *TokenState {
|
||||
state, exists := rl.states[tokenKey]
|
||||
if !exists {
|
||||
state = &TokenState{
|
||||
DailyResetTime: time.Now().Truncate(24 * time.Hour).Add(24 * time.Hour),
|
||||
}
|
||||
rl.states[tokenKey] = state
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
// resetDailyIfNeeded 如果需要则重置每日计数
|
||||
func (rl *RateLimiter) resetDailyIfNeeded(state *TokenState) {
|
||||
now := time.Now()
|
||||
if now.After(state.DailyResetTime) {
|
||||
state.DailyRequests = 0
|
||||
state.DailyResetTime = now.Truncate(24 * time.Hour).Add(24 * time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
// calculateInterval 计算带抖动的随机间隔
|
||||
func (rl *RateLimiter) calculateInterval() time.Duration {
|
||||
baseInterval := rl.minTokenInterval + time.Duration(rl.rng.Int63n(int64(rl.maxTokenInterval-rl.minTokenInterval)))
|
||||
jitter := time.Duration(float64(baseInterval) * rl.jitterPercent * (rl.rng.Float64()*2 - 1))
|
||||
return baseInterval + jitter
|
||||
}
|
||||
|
||||
// WaitForToken 等待 Token 可用(带抖动的随机间隔)
|
||||
func (rl *RateLimiter) WaitForToken(tokenKey string) {
|
||||
rl.mu.Lock()
|
||||
state := rl.getOrCreateState(tokenKey)
|
||||
rl.resetDailyIfNeeded(state)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// 检查是否在冷却期
|
||||
if now.Before(state.CooldownEnd) {
|
||||
waitTime := state.CooldownEnd.Sub(now)
|
||||
rl.mu.Unlock()
|
||||
time.Sleep(waitTime)
|
||||
rl.mu.Lock()
|
||||
state = rl.getOrCreateState(tokenKey)
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
// 计算距离上次请求的间隔
|
||||
interval := rl.calculateInterval()
|
||||
nextAllowedTime := state.LastRequest.Add(interval)
|
||||
|
||||
if now.Before(nextAllowedTime) {
|
||||
waitTime := nextAllowedTime.Sub(now)
|
||||
rl.mu.Unlock()
|
||||
time.Sleep(waitTime)
|
||||
rl.mu.Lock()
|
||||
state = rl.getOrCreateState(tokenKey)
|
||||
}
|
||||
|
||||
state.LastRequest = time.Now()
|
||||
state.RequestCount++
|
||||
state.DailyRequests++
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
|
||||
// MarkTokenFailed 标记 Token 失败
|
||||
func (rl *RateLimiter) MarkTokenFailed(tokenKey string) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
state := rl.getOrCreateState(tokenKey)
|
||||
state.FailCount++
|
||||
state.CooldownEnd = time.Now().Add(rl.calculateBackoff(state.FailCount))
|
||||
}
|
||||
|
||||
// MarkTokenSuccess 标记 Token 成功
|
||||
func (rl *RateLimiter) MarkTokenSuccess(tokenKey string) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
state := rl.getOrCreateState(tokenKey)
|
||||
state.FailCount = 0
|
||||
state.CooldownEnd = time.Time{}
|
||||
}
|
||||
|
||||
// CheckAndMarkSuspended 检测暂停错误并标记
|
||||
func (rl *RateLimiter) CheckAndMarkSuspended(tokenKey string, errorMsg string) bool {
|
||||
suspendKeywords := []string{
|
||||
"suspended",
|
||||
"banned",
|
||||
"disabled",
|
||||
"account has been",
|
||||
"access denied",
|
||||
"rate limit exceeded",
|
||||
"too many requests",
|
||||
"quota exceeded",
|
||||
}
|
||||
|
||||
lowerMsg := strings.ToLower(errorMsg)
|
||||
for _, keyword := range suspendKeywords {
|
||||
if strings.Contains(lowerMsg, keyword) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
state := rl.getOrCreateState(tokenKey)
|
||||
state.IsSuspended = true
|
||||
state.SuspendedAt = time.Now()
|
||||
state.SuspendReason = errorMsg
|
||||
state.CooldownEnd = time.Now().Add(rl.suspendCooldown)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsTokenAvailable 检查 Token 是否可用
|
||||
func (rl *RateLimiter) IsTokenAvailable(tokenKey string) bool {
|
||||
rl.mu.RLock()
|
||||
defer rl.mu.RUnlock()
|
||||
|
||||
state, exists := rl.states[tokenKey]
|
||||
if !exists {
|
||||
return true
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// 检查是否被暂停
|
||||
if state.IsSuspended {
|
||||
if now.After(state.SuspendedAt.Add(rl.suspendCooldown)) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否在冷却期
|
||||
if now.Before(state.CooldownEnd) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查每日请求限制
|
||||
rl.mu.RUnlock()
|
||||
rl.mu.Lock()
|
||||
rl.resetDailyIfNeeded(state)
|
||||
dailyRequests := state.DailyRequests
|
||||
dailyMax := rl.dailyMaxRequests
|
||||
rl.mu.Unlock()
|
||||
rl.mu.RLock()
|
||||
|
||||
if dailyRequests >= dailyMax {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// calculateBackoff 计算指数退避时间
|
||||
func (rl *RateLimiter) calculateBackoff(failCount int) time.Duration {
|
||||
if failCount <= 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
backoff := float64(rl.backoffBase) * math.Pow(rl.backoffMultiplier, float64(failCount-1))
|
||||
|
||||
// 添加抖动
|
||||
jitter := backoff * rl.jitterPercent * (rl.rng.Float64()*2 - 1)
|
||||
backoff += jitter
|
||||
|
||||
if time.Duration(backoff) > rl.backoffMax {
|
||||
return rl.backoffMax
|
||||
}
|
||||
return time.Duration(backoff)
|
||||
}
|
||||
|
||||
// GetTokenState 获取 Token 状态(只读)
|
||||
func (rl *RateLimiter) GetTokenState(tokenKey string) *TokenState {
|
||||
rl.mu.RLock()
|
||||
defer rl.mu.RUnlock()
|
||||
|
||||
state, exists := rl.states[tokenKey]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 返回副本以防止外部修改
|
||||
stateCopy := *state
|
||||
return &stateCopy
|
||||
}
|
||||
|
||||
// ClearTokenState 清除 Token 状态
|
||||
func (rl *RateLimiter) ClearTokenState(tokenKey string) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
delete(rl.states, tokenKey)
|
||||
}
|
||||
|
||||
// ResetSuspension 重置暂停状态
|
||||
func (rl *RateLimiter) ResetSuspension(tokenKey string) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
state, exists := rl.states[tokenKey]
|
||||
if exists {
|
||||
state.IsSuspended = false
|
||||
state.SuspendedAt = time.Time{}
|
||||
state.SuspendReason = ""
|
||||
state.CooldownEnd = time.Time{}
|
||||
state.FailCount = 0
|
||||
}
|
||||
}
|
||||
46
internal/auth/kiro/rate_limiter_singleton.go
Normal file
46
internal/auth/kiro/rate_limiter_singleton.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
globalRateLimiter *RateLimiter
|
||||
globalRateLimiterOnce sync.Once
|
||||
|
||||
globalCooldownManager *CooldownManager
|
||||
globalCooldownManagerOnce sync.Once
|
||||
cooldownStopCh chan struct{}
|
||||
)
|
||||
|
||||
// GetGlobalRateLimiter returns the singleton RateLimiter instance.
|
||||
func GetGlobalRateLimiter() *RateLimiter {
|
||||
globalRateLimiterOnce.Do(func() {
|
||||
globalRateLimiter = NewRateLimiter()
|
||||
log.Info("kiro: global RateLimiter initialized")
|
||||
})
|
||||
return globalRateLimiter
|
||||
}
|
||||
|
||||
// GetGlobalCooldownManager returns the singleton CooldownManager instance.
|
||||
func GetGlobalCooldownManager() *CooldownManager {
|
||||
globalCooldownManagerOnce.Do(func() {
|
||||
globalCooldownManager = NewCooldownManager()
|
||||
cooldownStopCh = make(chan struct{})
|
||||
go globalCooldownManager.StartCleanupRoutine(5*time.Minute, cooldownStopCh)
|
||||
log.Info("kiro: global CooldownManager initialized with cleanup routine")
|
||||
})
|
||||
return globalCooldownManager
|
||||
}
|
||||
|
||||
// ShutdownRateLimiters stops the cooldown cleanup routine.
|
||||
// Should be called during application shutdown.
|
||||
func ShutdownRateLimiters() {
|
||||
if cooldownStopCh != nil {
|
||||
close(cooldownStopCh)
|
||||
log.Info("kiro: rate limiter cleanup routine stopped")
|
||||
}
|
||||
}
|
||||
304
internal/auth/kiro/rate_limiter_test.go
Normal file
304
internal/auth/kiro/rate_limiter_test.go
Normal file
@@ -0,0 +1,304 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewRateLimiter(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
if rl == nil {
|
||||
t.Fatal("expected non-nil RateLimiter")
|
||||
}
|
||||
if rl.states == nil {
|
||||
t.Error("expected non-nil states map")
|
||||
}
|
||||
if rl.minTokenInterval != DefaultMinTokenInterval {
|
||||
t.Errorf("expected minTokenInterval %v, got %v", DefaultMinTokenInterval, rl.minTokenInterval)
|
||||
}
|
||||
if rl.maxTokenInterval != DefaultMaxTokenInterval {
|
||||
t.Errorf("expected maxTokenInterval %v, got %v", DefaultMaxTokenInterval, rl.maxTokenInterval)
|
||||
}
|
||||
if rl.dailyMaxRequests != DefaultDailyMaxRequests {
|
||||
t.Errorf("expected dailyMaxRequests %d, got %d", DefaultDailyMaxRequests, rl.dailyMaxRequests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRateLimiterWithConfig(t *testing.T) {
|
||||
cfg := RateLimiterConfig{
|
||||
MinTokenInterval: 5 * time.Second,
|
||||
MaxTokenInterval: 15 * time.Second,
|
||||
DailyMaxRequests: 100,
|
||||
JitterPercent: 0.2,
|
||||
BackoffBase: 1 * time.Minute,
|
||||
BackoffMax: 30 * time.Minute,
|
||||
BackoffMultiplier: 1.5,
|
||||
SuspendCooldown: 12 * time.Hour,
|
||||
}
|
||||
|
||||
rl := NewRateLimiterWithConfig(cfg)
|
||||
if rl.minTokenInterval != 5*time.Second {
|
||||
t.Errorf("expected minTokenInterval 5s, got %v", rl.minTokenInterval)
|
||||
}
|
||||
if rl.maxTokenInterval != 15*time.Second {
|
||||
t.Errorf("expected maxTokenInterval 15s, got %v", rl.maxTokenInterval)
|
||||
}
|
||||
if rl.dailyMaxRequests != 100 {
|
||||
t.Errorf("expected dailyMaxRequests 100, got %d", rl.dailyMaxRequests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRateLimiterWithConfig_PartialConfig(t *testing.T) {
|
||||
cfg := RateLimiterConfig{
|
||||
MinTokenInterval: 5 * time.Second,
|
||||
}
|
||||
|
||||
rl := NewRateLimiterWithConfig(cfg)
|
||||
if rl.minTokenInterval != 5*time.Second {
|
||||
t.Errorf("expected minTokenInterval 5s, got %v", rl.minTokenInterval)
|
||||
}
|
||||
if rl.maxTokenInterval != DefaultMaxTokenInterval {
|
||||
t.Errorf("expected default maxTokenInterval, got %v", rl.maxTokenInterval)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTokenState_NonExistent(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
state := rl.GetTokenState("nonexistent")
|
||||
if state != nil {
|
||||
t.Error("expected nil state for non-existent token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsTokenAvailable_NewToken(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
if !rl.IsTokenAvailable("newtoken") {
|
||||
t.Error("expected new token to be available")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkTokenFailed(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.MarkTokenFailed("token1")
|
||||
|
||||
state := rl.GetTokenState("token1")
|
||||
if state == nil {
|
||||
t.Fatal("expected non-nil state")
|
||||
}
|
||||
if state.FailCount != 1 {
|
||||
t.Errorf("expected FailCount 1, got %d", state.FailCount)
|
||||
}
|
||||
if state.CooldownEnd.IsZero() {
|
||||
t.Error("expected non-zero CooldownEnd")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkTokenSuccess(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.MarkTokenFailed("token1")
|
||||
rl.MarkTokenFailed("token1")
|
||||
rl.MarkTokenSuccess("token1")
|
||||
|
||||
state := rl.GetTokenState("token1")
|
||||
if state == nil {
|
||||
t.Fatal("expected non-nil state")
|
||||
}
|
||||
if state.FailCount != 0 {
|
||||
t.Errorf("expected FailCount 0, got %d", state.FailCount)
|
||||
}
|
||||
if !state.CooldownEnd.IsZero() {
|
||||
t.Error("expected zero CooldownEnd after success")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckAndMarkSuspended_Suspended(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
|
||||
testCases := []string{
|
||||
"Account has been suspended",
|
||||
"You are banned from this service",
|
||||
"Account disabled",
|
||||
"Access denied permanently",
|
||||
"Rate limit exceeded",
|
||||
"Too many requests",
|
||||
"Quota exceeded for today",
|
||||
}
|
||||
|
||||
for i, msg := range testCases {
|
||||
tokenKey := "token" + string(rune('a'+i))
|
||||
if !rl.CheckAndMarkSuspended(tokenKey, msg) {
|
||||
t.Errorf("expected suspension detected for: %s", msg)
|
||||
}
|
||||
state := rl.GetTokenState(tokenKey)
|
||||
if !state.IsSuspended {
|
||||
t.Errorf("expected IsSuspended true for: %s", msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckAndMarkSuspended_NotSuspended(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
|
||||
normalErrors := []string{
|
||||
"connection timeout",
|
||||
"internal server error",
|
||||
"bad request",
|
||||
"invalid token format",
|
||||
}
|
||||
|
||||
for i, msg := range normalErrors {
|
||||
tokenKey := "token" + string(rune('a'+i))
|
||||
if rl.CheckAndMarkSuspended(tokenKey, msg) {
|
||||
t.Errorf("unexpected suspension for: %s", msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsTokenAvailable_Suspended(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.CheckAndMarkSuspended("token1", "Account suspended")
|
||||
|
||||
if rl.IsTokenAvailable("token1") {
|
||||
t.Error("expected suspended token to be unavailable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearTokenState(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.MarkTokenFailed("token1")
|
||||
rl.ClearTokenState("token1")
|
||||
|
||||
state := rl.GetTokenState("token1")
|
||||
if state != nil {
|
||||
t.Error("expected nil state after clear")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetSuspension(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.CheckAndMarkSuspended("token1", "Account suspended")
|
||||
rl.ResetSuspension("token1")
|
||||
|
||||
state := rl.GetTokenState("token1")
|
||||
if state.IsSuspended {
|
||||
t.Error("expected IsSuspended false after reset")
|
||||
}
|
||||
if state.FailCount != 0 {
|
||||
t.Errorf("expected FailCount 0, got %d", state.FailCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetSuspension_NonExistent(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.ResetSuspension("nonexistent")
|
||||
}
|
||||
|
||||
func TestCalculateBackoff_ZeroFailCount(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
backoff := rl.calculateBackoff(0)
|
||||
if backoff != 0 {
|
||||
t.Errorf("expected 0 backoff for 0 fails, got %v", backoff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateBackoff_Exponential(t *testing.T) {
|
||||
cfg := RateLimiterConfig{
|
||||
BackoffBase: 1 * time.Minute,
|
||||
BackoffMax: 60 * time.Minute,
|
||||
BackoffMultiplier: 2.0,
|
||||
JitterPercent: 0.3,
|
||||
}
|
||||
rl := NewRateLimiterWithConfig(cfg)
|
||||
|
||||
backoff1 := rl.calculateBackoff(1)
|
||||
if backoff1 < 40*time.Second || backoff1 > 80*time.Second {
|
||||
t.Errorf("expected ~1min (with jitter) for fail 1, got %v", backoff1)
|
||||
}
|
||||
|
||||
backoff2 := rl.calculateBackoff(2)
|
||||
if backoff2 < 80*time.Second || backoff2 > 160*time.Second {
|
||||
t.Errorf("expected ~2min (with jitter) for fail 2, got %v", backoff2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateBackoff_MaxCap(t *testing.T) {
|
||||
cfg := RateLimiterConfig{
|
||||
BackoffBase: 1 * time.Minute,
|
||||
BackoffMax: 10 * time.Minute,
|
||||
BackoffMultiplier: 2.0,
|
||||
JitterPercent: 0,
|
||||
}
|
||||
rl := NewRateLimiterWithConfig(cfg)
|
||||
|
||||
backoff := rl.calculateBackoff(10)
|
||||
if backoff > 10*time.Minute {
|
||||
t.Errorf("expected backoff capped at 10min, got %v", backoff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTokenState_ReturnsCopy(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.MarkTokenFailed("token1")
|
||||
|
||||
state1 := rl.GetTokenState("token1")
|
||||
state1.FailCount = 999
|
||||
|
||||
state2 := rl.GetTokenState("token1")
|
||||
if state2.FailCount == 999 {
|
||||
t.Error("GetTokenState should return a copy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_ConcurrentAccess(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
const numGoroutines = 50
|
||||
const numOperations = 50
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
tokenKey := "token" + string(rune('a'+id%10))
|
||||
for j := 0; j < numOperations; j++ {
|
||||
switch j % 6 {
|
||||
case 0:
|
||||
rl.IsTokenAvailable(tokenKey)
|
||||
case 1:
|
||||
rl.MarkTokenFailed(tokenKey)
|
||||
case 2:
|
||||
rl.MarkTokenSuccess(tokenKey)
|
||||
case 3:
|
||||
rl.GetTokenState(tokenKey)
|
||||
case 4:
|
||||
rl.CheckAndMarkSuspended(tokenKey, "test error")
|
||||
case 5:
|
||||
rl.ResetSuspension(tokenKey)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestCalculateInterval_WithinRange(t *testing.T) {
|
||||
cfg := RateLimiterConfig{
|
||||
MinTokenInterval: 10 * time.Second,
|
||||
MaxTokenInterval: 30 * time.Second,
|
||||
JitterPercent: 0.3,
|
||||
}
|
||||
rl := NewRateLimiterWithConfig(cfg)
|
||||
|
||||
minAllowed := 7 * time.Second
|
||||
maxAllowed := 40 * time.Second
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
interval := rl.calculateInterval()
|
||||
if interval < minAllowed || interval > maxAllowed {
|
||||
t.Errorf("interval %v outside expected range [%v, %v]", interval, minAllowed, maxAllowed)
|
||||
}
|
||||
}
|
||||
}
|
||||
171
internal/auth/kiro/refresh_manager.go
Normal file
171
internal/auth/kiro/refresh_manager.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// RefreshManager 是后台刷新器的单例管理器
|
||||
type RefreshManager struct {
|
||||
mu sync.Mutex
|
||||
refresher *BackgroundRefresher
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
started bool
|
||||
onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // 刷新成功回调
|
||||
}
|
||||
|
||||
var (
|
||||
globalRefreshManager *RefreshManager
|
||||
managerOnce sync.Once
|
||||
)
|
||||
|
||||
// GetRefreshManager 获取全局刷新管理器实例
|
||||
func GetRefreshManager() *RefreshManager {
|
||||
managerOnce.Do(func() {
|
||||
globalRefreshManager = &RefreshManager{}
|
||||
})
|
||||
return globalRefreshManager
|
||||
}
|
||||
|
||||
// Initialize 初始化后台刷新器
|
||||
// baseDir: token 文件所在的目录
|
||||
// cfg: 应用配置
|
||||
func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.started {
|
||||
log.Debug("refresh manager: already initialized")
|
||||
return nil
|
||||
}
|
||||
|
||||
if baseDir == "" {
|
||||
log.Warn("refresh manager: base directory not provided, skipping initialization")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 创建 token 存储库
|
||||
repo := NewFileTokenRepository(baseDir)
|
||||
|
||||
// 创建后台刷新器,配置参数
|
||||
opts := []RefresherOption{
|
||||
WithInterval(time.Minute), // 每分钟检查一次
|
||||
WithBatchSize(50), // 每批最多处理 50 个 token
|
||||
WithConcurrency(10), // 最多 10 个并发刷新
|
||||
WithConfig(cfg), // 设置 OAuth 和 SSO 客户端
|
||||
}
|
||||
|
||||
// 如果已设置回调,传递给 BackgroundRefresher
|
||||
if m.onTokenRefreshed != nil {
|
||||
opts = append(opts, WithOnTokenRefreshed(m.onTokenRefreshed))
|
||||
}
|
||||
|
||||
m.refresher = NewBackgroundRefresher(repo, opts...)
|
||||
|
||||
log.Infof("refresh manager: initialized with base directory %s", baseDir)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start 启动后台刷新
|
||||
func (m *RefreshManager) Start() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.started {
|
||||
log.Debug("refresh manager: already started")
|
||||
return
|
||||
}
|
||||
|
||||
if m.refresher == nil {
|
||||
log.Warn("refresh manager: not initialized, cannot start")
|
||||
return
|
||||
}
|
||||
|
||||
m.ctx, m.cancel = context.WithCancel(context.Background())
|
||||
m.refresher.Start(m.ctx)
|
||||
m.started = true
|
||||
|
||||
log.Info("refresh manager: background refresh started")
|
||||
}
|
||||
|
||||
// Stop 停止后台刷新
|
||||
func (m *RefreshManager) Stop() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if !m.started {
|
||||
return
|
||||
}
|
||||
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
}
|
||||
|
||||
if m.refresher != nil {
|
||||
m.refresher.Stop()
|
||||
}
|
||||
|
||||
m.started = false
|
||||
log.Info("refresh manager: background refresh stopped")
|
||||
}
|
||||
|
||||
// IsRunning 检查后台刷新是否正在运行
|
||||
func (m *RefreshManager) IsRunning() bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.started
|
||||
}
|
||||
|
||||
// UpdateBaseDir 更新 token 目录(用于运行时配置更改)
|
||||
func (m *RefreshManager) UpdateBaseDir(baseDir string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.refresher != nil && m.refresher.tokenRepo != nil {
|
||||
if repo, ok := m.refresher.tokenRepo.(*FileTokenRepository); ok {
|
||||
repo.SetBaseDir(baseDir)
|
||||
log.Infof("refresh manager: updated base directory to %s", baseDir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetOnTokenRefreshed 设置 token 刷新成功后的回调函数
|
||||
// 可以在任何时候调用,支持运行时更新回调
|
||||
// callback: 回调函数,接收 tokenID(文件名)和新的 token 数据
|
||||
func (m *RefreshManager) SetOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.onTokenRefreshed = callback
|
||||
|
||||
// 如果 refresher 已经创建,使用并发安全的方式更新它的回调
|
||||
if m.refresher != nil {
|
||||
m.refresher.callbackMu.Lock()
|
||||
m.refresher.onTokenRefreshed = callback
|
||||
m.refresher.callbackMu.Unlock()
|
||||
}
|
||||
|
||||
log.Debug("refresh manager: token refresh callback registered")
|
||||
}
|
||||
|
||||
// InitializeAndStart 初始化并启动后台刷新(便捷方法)
|
||||
func InitializeAndStart(baseDir string, cfg *config.Config) {
|
||||
manager := GetRefreshManager()
|
||||
if err := manager.Initialize(baseDir, cfg); err != nil {
|
||||
log.Errorf("refresh manager: initialization failed: %v", err)
|
||||
return
|
||||
}
|
||||
manager.Start()
|
||||
}
|
||||
|
||||
// StopGlobalRefreshManager 停止全局刷新管理器
|
||||
func StopGlobalRefreshManager() {
|
||||
if globalRefreshManager != nil {
|
||||
globalRefreshManager.Stop()
|
||||
}
|
||||
}
|
||||
@@ -9,7 +9,9 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -31,6 +33,9 @@ const (
|
||||
|
||||
// OAuth timeout
|
||||
socialAuthTimeout = 10 * time.Minute
|
||||
|
||||
// Default callback port for social auth HTTP server
|
||||
socialAuthCallbackPort = 9876
|
||||
)
|
||||
|
||||
// SocialProvider represents the social login provider.
|
||||
@@ -67,6 +72,13 @@ type RefreshTokenRequest struct {
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
// WebCallbackResult contains the OAuth callback result from HTTP server.
|
||||
type WebCallbackResult struct {
|
||||
Code string
|
||||
State string
|
||||
Error string
|
||||
}
|
||||
|
||||
// SocialAuthClient handles social authentication with Kiro.
|
||||
type SocialAuthClient struct {
|
||||
httpClient *http.Client
|
||||
@@ -87,6 +99,83 @@ func NewSocialAuthClient(cfg *config.Config) *SocialAuthClient {
|
||||
}
|
||||
}
|
||||
|
||||
// startWebCallbackServer starts a local HTTP server to receive the OAuth callback.
|
||||
// This is used instead of the kiro:// protocol handler to avoid redirect_mismatch errors.
|
||||
func (c *SocialAuthClient) startWebCallbackServer(ctx context.Context, expectedState string) (string, <-chan WebCallbackResult, error) {
|
||||
// Try to find an available port - use localhost like Kiro does
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", socialAuthCallbackPort))
|
||||
if err != nil {
|
||||
// Try with dynamic port (RFC 8252 allows dynamic ports for native apps)
|
||||
log.Warnf("kiro social auth: default port %d is busy, falling back to dynamic port", socialAuthCallbackPort)
|
||||
listener, err = net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to start callback server: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
// Use http scheme for local callback server
|
||||
redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", port)
|
||||
resultChan := make(chan WebCallbackResult, 1)
|
||||
|
||||
server := &http.Server{
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
errParam := r.URL.Query().Get("error")
|
||||
|
||||
if errParam != "" {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(w, `<!DOCTYPE html>
|
||||
<html><head><title>Login Failed</title></head>
|
||||
<body><h1>Login Failed</h1><p>%s</p><p>You can close this window.</p></body></html>`, html.EscapeString(errParam))
|
||||
resultChan <- WebCallbackResult{Error: errParam}
|
||||
return
|
||||
}
|
||||
|
||||
if state != expectedState {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(w, `<!DOCTYPE html>
|
||||
<html><head><title>Login Failed</title></head>
|
||||
<body><h1>Login Failed</h1><p>Invalid state parameter</p><p>You can close this window.</p></body></html>`)
|
||||
resultChan <- WebCallbackResult{Error: "state mismatch"}
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
fmt.Fprint(w, `<!DOCTYPE html>
|
||||
<html><head><title>Login Successful</title></head>
|
||||
<body><h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p>
|
||||
<script>window.close();</script></body></html>`)
|
||||
resultChan <- WebCallbackResult{Code: code, State: state}
|
||||
})
|
||||
|
||||
server.Handler = mux
|
||||
|
||||
go func() {
|
||||
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||
log.Debugf("kiro social auth callback server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(socialAuthTimeout):
|
||||
case <-resultChan:
|
||||
}
|
||||
_ = server.Shutdown(context.Background())
|
||||
}()
|
||||
|
||||
return redirectURI, resultChan, nil
|
||||
}
|
||||
|
||||
// generatePKCE generates PKCE code verifier and challenge.
|
||||
func generatePKCE() (verifier, challenge string, err error) {
|
||||
// Generate 32 bytes of random data for verifier
|
||||
@@ -217,10 +306,12 @@ func (c *SocialAuthClient) RefreshSocialToken(ctx context.Context, refreshToken
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: "", // Caller should preserve original provider
|
||||
Region: "us-east-1",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LoginWithSocial performs OAuth login with Google.
|
||||
// LoginWithSocial performs OAuth login with Google or GitHub.
|
||||
// Uses local HTTP callback server instead of custom protocol handler to avoid redirect_mismatch errors.
|
||||
func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialProvider) (*KiroTokenData, error) {
|
||||
providerName := string(provider)
|
||||
|
||||
@@ -228,28 +319,10 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
|
||||
fmt.Printf("║ Kiro Authentication (%s) ║\n", providerName)
|
||||
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||
|
||||
// Step 1: Setup protocol handler
|
||||
// Step 1: Start local HTTP callback server (instead of kiro:// protocol handler)
|
||||
// This avoids redirect_mismatch errors with AWS Cognito
|
||||
fmt.Println("\nSetting up authentication...")
|
||||
|
||||
// Start the local callback server
|
||||
handlerPort, err := c.protocolHandler.Start(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start callback server: %w", err)
|
||||
}
|
||||
defer c.protocolHandler.Stop()
|
||||
|
||||
// Ensure protocol handler is installed and set as default
|
||||
if err := SetupProtocolHandlerIfNeeded(handlerPort); err != nil {
|
||||
fmt.Println("\n⚠ Protocol handler setup failed. Trying alternative method...")
|
||||
fmt.Println(" If you see a browser 'Open with' dialog, select your default browser.")
|
||||
fmt.Println(" For manual setup instructions, run: cliproxy kiro --help-protocol")
|
||||
log.Debugf("kiro: protocol handler setup error: %v", err)
|
||||
// Continue anyway - user might have set it up manually or select browser manually
|
||||
} else {
|
||||
// Force set our handler as default (prevents "Open with" dialog)
|
||||
forceDefaultProtocolHandler()
|
||||
}
|
||||
|
||||
// Step 2: Generate PKCE codes
|
||||
codeVerifier, codeChallenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
@@ -262,8 +335,15 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Build the login URL (Kiro uses GET request with query params)
|
||||
authURL := c.buildLoginURL(providerName, KiroRedirectURI, codeChallenge, state)
|
||||
// Step 4: Start local HTTP callback server
|
||||
redirectURI, resultChan, err := c.startWebCallbackServer(ctx, state)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start callback server: %w", err)
|
||||
}
|
||||
log.Debugf("kiro social auth: callback server started at %s", redirectURI)
|
||||
|
||||
// Step 5: Build the login URL using HTTP redirect URI
|
||||
authURL := c.buildLoginURL(providerName, redirectURI, codeChallenge, state)
|
||||
|
||||
// Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito)
|
||||
// Incognito mode enables multi-account support by bypassing cached sessions
|
||||
@@ -279,7 +359,7 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
|
||||
log.Debug("kiro: using incognito mode for multi-account support (default)")
|
||||
}
|
||||
|
||||
// Step 5: Open browser for user authentication
|
||||
// Step 6: Open browser for user authentication
|
||||
fmt.Println("\n════════════════════════════════════════════════════════════")
|
||||
fmt.Printf(" Opening browser for %s authentication...\n", providerName)
|
||||
fmt.Println("════════════════════════════════════════════════════════════")
|
||||
@@ -295,80 +375,78 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
|
||||
|
||||
fmt.Println("\n Waiting for authentication callback...")
|
||||
|
||||
// Step 6: Wait for callback
|
||||
callback, err := c.protocolHandler.WaitForCallback(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to receive callback: %w", err)
|
||||
}
|
||||
|
||||
if callback.Error != "" {
|
||||
return nil, fmt.Errorf("authentication error: %s", callback.Error)
|
||||
}
|
||||
|
||||
if callback.State != state {
|
||||
// Log state values for debugging, but don't expose in user-facing error
|
||||
log.Debugf("kiro: OAuth state mismatch - expected %s, got %s", state, callback.State)
|
||||
return nil, fmt.Errorf("OAuth state validation failed - please try again")
|
||||
}
|
||||
|
||||
if callback.Code == "" {
|
||||
return nil, fmt.Errorf("no authorization code received")
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Authorization received!")
|
||||
|
||||
// Step 7: Exchange code for tokens
|
||||
fmt.Println("Exchanging code for tokens...")
|
||||
|
||||
tokenReq := &CreateTokenRequest{
|
||||
Code: callback.Code,
|
||||
CodeVerifier: codeVerifier,
|
||||
RedirectURI: KiroRedirectURI,
|
||||
}
|
||||
|
||||
tokenResp, err := c.CreateToken(ctx, tokenReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Authentication successful!")
|
||||
|
||||
// Close the browser window
|
||||
if err := browser.CloseBrowser(); err != nil {
|
||||
log.Debugf("Failed to close browser: %v", err)
|
||||
}
|
||||
|
||||
// Validate ExpiresIn - use default 1 hour if invalid
|
||||
expiresIn := tokenResp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
// Try to extract email from JWT access token first
|
||||
email := ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||
|
||||
// If no email in JWT, ask user for account label (only in interactive mode)
|
||||
if email == "" && isInteractiveTerminal() {
|
||||
fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ")
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
var err error
|
||||
email, err = reader.ReadString('\n')
|
||||
if err != nil {
|
||||
log.Debugf("Failed to read account label: %v", err)
|
||||
// Step 7: Wait for callback from HTTP server
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(socialAuthTimeout):
|
||||
return nil, fmt.Errorf("authentication timed out")
|
||||
case callback := <-resultChan:
|
||||
if callback.Error != "" {
|
||||
return nil, fmt.Errorf("authentication error: %s", callback.Error)
|
||||
}
|
||||
email = strings.TrimSpace(email)
|
||||
}
|
||||
|
||||
return &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: tokenResp.ProfileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: providerName,
|
||||
Email: email, // JWT email or user-provided label
|
||||
}, nil
|
||||
// State is already validated by the callback server
|
||||
if callback.Code == "" {
|
||||
return nil, fmt.Errorf("no authorization code received")
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Authorization received!")
|
||||
|
||||
// Step 8: Exchange code for tokens
|
||||
fmt.Println("Exchanging code for tokens...")
|
||||
|
||||
tokenReq := &CreateTokenRequest{
|
||||
Code: callback.Code,
|
||||
CodeVerifier: codeVerifier,
|
||||
RedirectURI: redirectURI, // Use HTTP redirect URI, not kiro:// protocol
|
||||
}
|
||||
|
||||
tokenResp, err := c.CreateToken(ctx, tokenReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Authentication successful!")
|
||||
|
||||
// Close the browser window
|
||||
if err := browser.CloseBrowser(); err != nil {
|
||||
log.Debugf("Failed to close browser: %v", err)
|
||||
}
|
||||
|
||||
// Validate ExpiresIn - use default 1 hour if invalid
|
||||
expiresIn := tokenResp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
// Try to extract email from JWT access token first
|
||||
email := ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||
|
||||
// If no email in JWT, ask user for account label (only in interactive mode)
|
||||
if email == "" && isInteractiveTerminal() {
|
||||
fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ")
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
var err error
|
||||
email, err = reader.ReadString('\n')
|
||||
if err != nil {
|
||||
log.Debugf("Failed to read account label: %v", err)
|
||||
}
|
||||
email = strings.TrimSpace(email)
|
||||
}
|
||||
|
||||
return &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: tokenResp.ProfileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: providerName,
|
||||
Email: email, // JWT email or user-provided label
|
||||
Region: "us-east-1",
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// LoginWithGoogle performs OAuth login with Google.
|
||||
|
||||
@@ -2,11 +2,19 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -19,12 +27,31 @@ import (
|
||||
const (
|
||||
// AWS SSO OIDC endpoints
|
||||
ssoOIDCEndpoint = "https://oidc.us-east-1.amazonaws.com"
|
||||
|
||||
|
||||
// Kiro's start URL for Builder ID
|
||||
builderIDStartURL = "https://view.awsapps.com/start"
|
||||
|
||||
|
||||
// Default region for IDC
|
||||
defaultIDCRegion = "us-east-1"
|
||||
|
||||
// Polling interval
|
||||
pollInterval = 5 * time.Second
|
||||
|
||||
// Authorization code flow callback
|
||||
authCodeCallbackPath = "/oauth/callback"
|
||||
authCodeCallbackPort = 19877
|
||||
|
||||
// User-Agent to match official Kiro IDE
|
||||
kiroUserAgent = "KiroIDE"
|
||||
|
||||
// IDC token refresh headers (matching Kiro IDE behavior)
|
||||
idcAmzUserAgent = "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE"
|
||||
)
|
||||
|
||||
// Sentinel errors for OIDC token polling
|
||||
var (
|
||||
ErrAuthorizationPending = errors.New("authorization_pending")
|
||||
ErrSlowDown = errors.New("slow_down")
|
||||
)
|
||||
|
||||
// SSOOIDCClient handles AWS SSO OIDC authentication.
|
||||
@@ -71,15 +98,447 @@ type CreateTokenResponse struct {
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
// getOIDCEndpoint returns the OIDC endpoint for the given region.
|
||||
func getOIDCEndpoint(region string) string {
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
return fmt.Sprintf("https://oidc.%s.amazonaws.com", region)
|
||||
}
|
||||
|
||||
// promptInput prompts the user for input with an optional default value.
|
||||
func promptInput(prompt, defaultValue string) string {
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
if defaultValue != "" {
|
||||
fmt.Printf("%s [%s]: ", prompt, defaultValue)
|
||||
} else {
|
||||
fmt.Printf("%s: ", prompt)
|
||||
}
|
||||
input, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
log.Warnf("Error reading input: %v", err)
|
||||
return defaultValue
|
||||
}
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
// promptSelect prompts the user to select from options using number input.
|
||||
func promptSelect(prompt string, options []string) int {
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
|
||||
for {
|
||||
fmt.Println(prompt)
|
||||
for i, opt := range options {
|
||||
fmt.Printf(" %d) %s\n", i+1, opt)
|
||||
}
|
||||
fmt.Printf("Enter selection (1-%d): ", len(options))
|
||||
|
||||
input, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
log.Warnf("Error reading input: %v", err)
|
||||
return 0 // Default to first option on error
|
||||
}
|
||||
input = strings.TrimSpace(input)
|
||||
|
||||
// Parse the selection
|
||||
var selection int
|
||||
if _, err := fmt.Sscanf(input, "%d", &selection); err != nil || selection < 1 || selection > len(options) {
|
||||
fmt.Printf("Invalid selection '%s'. Please enter a number between 1 and %d.\n\n", input, len(options))
|
||||
continue
|
||||
}
|
||||
return selection - 1
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterClientWithRegion registers a new OIDC client with AWS using a specific region.
|
||||
func (c *SSOOIDCClient) RegisterClientWithRegion(ctx context.Context, region string) (*RegisterClientResponse, error) {
|
||||
endpoint := getOIDCEndpoint(region)
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"clientName": "Kiro IDE",
|
||||
"clientType": "public",
|
||||
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
|
||||
"grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"},
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/client/register", strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("register client failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result RegisterClientResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// StartDeviceAuthorizationWithIDC starts the device authorization flow for IDC.
|
||||
func (c *SSOOIDCClient) StartDeviceAuthorizationWithIDC(ctx context.Context, clientID, clientSecret, startURL, region string) (*StartDeviceAuthResponse, error) {
|
||||
endpoint := getOIDCEndpoint(region)
|
||||
|
||||
payload := map[string]string{
|
||||
"clientId": clientID,
|
||||
"clientSecret": clientSecret,
|
||||
"startUrl": startURL,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/device_authorization", strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("start device auth failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("start device auth failed (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result StartDeviceAuthResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// CreateTokenWithRegion polls for the access token after user authorization using a specific region.
|
||||
func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, clientSecret, deviceCode, region string) (*CreateTokenResponse, error) {
|
||||
endpoint := getOIDCEndpoint(region)
|
||||
|
||||
payload := map[string]string{
|
||||
"clientId": clientID,
|
||||
"clientSecret": clientSecret,
|
||||
"deviceCode": deviceCode,
|
||||
"grantType": "urn:ietf:params:oauth:grant-type:device_code",
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check for pending authorization
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
var errResp struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if json.Unmarshal(respBody, &errResp) == nil {
|
||||
if errResp.Error == "authorization_pending" {
|
||||
return nil, ErrAuthorizationPending
|
||||
}
|
||||
if errResp.Error == "slow_down" {
|
||||
return nil, ErrSlowDown
|
||||
}
|
||||
}
|
||||
log.Debugf("create token failed: %s", string(respBody))
|
||||
return nil, fmt.Errorf("create token failed")
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("create token failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result CreateTokenResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// RefreshTokenWithRegion refreshes an access token using the refresh token with a specific region.
|
||||
func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, clientSecret, refreshToken, region, startURL string) (*KiroTokenData, error) {
|
||||
endpoint := getOIDCEndpoint(region)
|
||||
|
||||
payload := map[string]string{
|
||||
"clientId": clientID,
|
||||
"clientSecret": clientSecret,
|
||||
"refreshToken": refreshToken,
|
||||
"grantType": "refresh_token",
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set headers matching kiro2api's IDC token refresh
|
||||
// These headers are required for successful IDC token refresh
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region))
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
req.Header.Set("x-amz-user-agent", idcAmzUserAgent)
|
||||
req.Header.Set("Accept", "*/*")
|
||||
req.Header.Set("Accept-Language", "*")
|
||||
req.Header.Set("sec-fetch-mode", "cors")
|
||||
req.Header.Set("User-Agent", "node")
|
||||
req.Header.Set("Accept-Encoding", "br, gzip, deflate")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Warnf("IDC token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result CreateTokenResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second)
|
||||
|
||||
return &KiroTokenData{
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "idc",
|
||||
Provider: "AWS",
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
StartURL: startURL,
|
||||
Region: region,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LoginWithIDC performs the full device code flow for AWS Identity Center (IDC).
|
||||
func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region string) (*KiroTokenData, error) {
|
||||
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
|
||||
fmt.Println("║ Kiro Authentication (AWS Identity Center) ║")
|
||||
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||
|
||||
// Step 1: Register client with the specified region
|
||||
fmt.Println("\nRegistering client...")
|
||||
regResp, err := c.RegisterClientWithRegion(ctx, region)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to register client: %w", err)
|
||||
}
|
||||
log.Debugf("Client registered: %s", regResp.ClientID)
|
||||
|
||||
// Step 2: Start device authorization with IDC start URL
|
||||
fmt.Println("Starting device authorization...")
|
||||
authResp, err := c.StartDeviceAuthorizationWithIDC(ctx, regResp.ClientID, regResp.ClientSecret, startURL, region)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start device auth: %w", err)
|
||||
}
|
||||
|
||||
// Step 3: Show user the verification URL
|
||||
fmt.Printf("\n")
|
||||
fmt.Println("════════════════════════════════════════════════════════════")
|
||||
fmt.Printf(" Confirm the following code in the browser:\n")
|
||||
fmt.Printf(" Code: %s\n", authResp.UserCode)
|
||||
fmt.Println("════════════════════════════════════════════════════════════")
|
||||
fmt.Printf("\n Open this URL: %s\n\n", authResp.VerificationURIComplete)
|
||||
|
||||
// Set incognito mode based on config
|
||||
if c.cfg != nil {
|
||||
browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
|
||||
if !c.cfg.IncognitoBrowser {
|
||||
log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.")
|
||||
} else {
|
||||
log.Debug("kiro: using incognito mode for multi-account support")
|
||||
}
|
||||
} else {
|
||||
browser.SetIncognitoMode(true)
|
||||
log.Debug("kiro: using incognito mode for multi-account support (default)")
|
||||
}
|
||||
|
||||
// Open browser
|
||||
if err := browser.OpenURL(authResp.VerificationURIComplete); err != nil {
|
||||
log.Warnf("Could not open browser automatically: %v", err)
|
||||
fmt.Println(" Please open the URL manually in your browser.")
|
||||
} else {
|
||||
fmt.Println(" (Browser opened automatically)")
|
||||
}
|
||||
|
||||
// Step 4: Poll for token
|
||||
fmt.Println("Waiting for authorization...")
|
||||
|
||||
interval := pollInterval
|
||||
if authResp.Interval > 0 {
|
||||
interval = time.Duration(authResp.Interval) * time.Second
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
browser.CloseBrowser()
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(interval):
|
||||
tokenResp, err := c.CreateTokenWithRegion(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode, region)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrAuthorizationPending) {
|
||||
fmt.Print(".")
|
||||
continue
|
||||
}
|
||||
if errors.Is(err, ErrSlowDown) {
|
||||
interval += 5 * time.Second
|
||||
continue
|
||||
}
|
||||
browser.CloseBrowser()
|
||||
return nil, fmt.Errorf("token creation failed: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("\n\n✓ Authorization successful!")
|
||||
|
||||
// Close the browser window
|
||||
if err := browser.CloseBrowser(); err != nil {
|
||||
log.Debugf("Failed to close browser: %v", err)
|
||||
}
|
||||
|
||||
// Step 5: Get profile ARN from CodeWhisperer API
|
||||
fmt.Println("Fetching profile information...")
|
||||
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
|
||||
|
||||
// Fetch user email
|
||||
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
|
||||
if email != "" {
|
||||
fmt.Printf(" Logged in as: %s\n", email)
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||
|
||||
return &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: profileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "idc",
|
||||
Provider: "AWS",
|
||||
ClientID: regResp.ClientID,
|
||||
ClientSecret: regResp.ClientSecret,
|
||||
Email: email,
|
||||
StartURL: startURL,
|
||||
Region: region,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Close browser on timeout
|
||||
if err := browser.CloseBrowser(); err != nil {
|
||||
log.Debugf("Failed to close browser on timeout: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("authorization timed out")
|
||||
}
|
||||
|
||||
// LoginWithMethodSelection prompts the user to select between Builder ID and IDC, then performs the login.
|
||||
func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context) (*KiroTokenData, error) {
|
||||
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
|
||||
fmt.Println("║ Kiro Authentication (AWS) ║")
|
||||
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||
|
||||
// Prompt for login method
|
||||
options := []string{
|
||||
"Use with Builder ID (personal AWS account)",
|
||||
"Use with IDC Account (organization SSO)",
|
||||
}
|
||||
selection := promptSelect("\n? Select login method:", options)
|
||||
|
||||
if selection == 0 {
|
||||
// Builder ID flow - use existing implementation
|
||||
return c.LoginWithBuilderID(ctx)
|
||||
}
|
||||
|
||||
// IDC flow - prompt for start URL and region
|
||||
fmt.Println()
|
||||
startURL := promptInput("? Enter Start URL", "")
|
||||
if startURL == "" {
|
||||
return nil, fmt.Errorf("start URL is required for IDC login")
|
||||
}
|
||||
|
||||
region := promptInput("? Enter Region", defaultIDCRegion)
|
||||
|
||||
return c.LoginWithIDC(ctx, startURL, region)
|
||||
}
|
||||
|
||||
// RegisterClient registers a new OIDC client with AWS.
|
||||
func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResponse, error) {
|
||||
// Generate unique client name for each registration to support multiple accounts
|
||||
clientName := fmt.Sprintf("CLI-Proxy-API-%d", time.Now().UnixNano())
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"clientName": clientName,
|
||||
"clientName": "Kiro IDE",
|
||||
"clientType": "public",
|
||||
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations"},
|
||||
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
|
||||
"grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"},
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
@@ -92,6 +551,7 @@ func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResp
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -135,6 +595,7 @@ func (c *SSOOIDCClient) StartDeviceAuthorization(ctx context.Context, clientID,
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -179,6 +640,7 @@ func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret,
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -198,10 +660,10 @@ func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret,
|
||||
}
|
||||
if json.Unmarshal(respBody, &errResp) == nil {
|
||||
if errResp.Error == "authorization_pending" {
|
||||
return nil, fmt.Errorf("authorization_pending")
|
||||
return nil, ErrAuthorizationPending
|
||||
}
|
||||
if errResp.Error == "slow_down" {
|
||||
return nil, fmt.Errorf("slow_down")
|
||||
return nil, ErrSlowDown
|
||||
}
|
||||
}
|
||||
log.Debugf("create token failed: %s", string(respBody))
|
||||
@@ -240,6 +702,7 @@ func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -272,6 +735,7 @@ func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret
|
||||
Provider: "AWS",
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
Region: defaultIDCRegion,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -345,12 +809,11 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
|
||||
case <-time.After(interval):
|
||||
tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
if strings.Contains(errStr, "authorization_pending") {
|
||||
if errors.Is(err, ErrAuthorizationPending) {
|
||||
fmt.Print(".")
|
||||
continue
|
||||
}
|
||||
if strings.Contains(errStr, "slow_down") {
|
||||
if errors.Is(err, ErrSlowDown) {
|
||||
interval += 5 * time.Second
|
||||
continue
|
||||
}
|
||||
@@ -370,8 +833,8 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
|
||||
fmt.Println("Fetching profile information...")
|
||||
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
|
||||
|
||||
// Extract email from JWT access token
|
||||
email := ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||
// Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing)
|
||||
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
|
||||
if email != "" {
|
||||
fmt.Printf(" Logged in as: %s\n", email)
|
||||
}
|
||||
@@ -388,15 +851,78 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
|
||||
ClientID: regResp.ClientID,
|
||||
ClientSecret: regResp.ClientSecret,
|
||||
Email: email,
|
||||
Region: defaultIDCRegion,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close browser on timeout for better UX
|
||||
if err := browser.CloseBrowser(); err != nil {
|
||||
log.Debugf("Failed to close browser on timeout: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("authorization timed out")
|
||||
}
|
||||
|
||||
// FetchUserEmail retrieves the user's email from AWS SSO OIDC userinfo endpoint.
|
||||
// Falls back to JWT parsing if userinfo fails.
|
||||
func (c *SSOOIDCClient) FetchUserEmail(ctx context.Context, accessToken string) string {
|
||||
// Method 1: Try userinfo endpoint (standard OIDC)
|
||||
email := c.tryUserInfoEndpoint(ctx, accessToken)
|
||||
if email != "" {
|
||||
return email
|
||||
}
|
||||
|
||||
// Close browser on timeout for better UX
|
||||
if err := browser.CloseBrowser(); err != nil {
|
||||
log.Debugf("Failed to close browser on timeout: %v", err)
|
||||
// Method 2: Fallback to JWT parsing
|
||||
return ExtractEmailFromJWT(accessToken)
|
||||
}
|
||||
|
||||
// tryUserInfoEndpoint attempts to get user info from AWS SSO OIDC userinfo endpoint.
|
||||
func (c *SSOOIDCClient) tryUserInfoEndpoint(ctx context.Context, accessToken string) string {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, ssoOIDCEndpoint+"/userinfo", nil)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return nil, fmt.Errorf("authorization timed out")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Debugf("userinfo request failed: %v", err)
|
||||
return ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
log.Debugf("userinfo endpoint returned status %d: %s", resp.StatusCode, string(respBody))
|
||||
return ""
|
||||
}
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
log.Debugf("userinfo response: %s", string(respBody))
|
||||
|
||||
var userInfo struct {
|
||||
Email string `json:"email"`
|
||||
Sub string `json:"sub"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(respBody, &userInfo); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if userInfo.Email != "" {
|
||||
return userInfo.Email
|
||||
}
|
||||
if userInfo.PreferredUsername != "" && strings.Contains(userInfo.PreferredUsername, "@") {
|
||||
return userInfo.PreferredUsername
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// fetchProfileArn retrieves the profile ARN from CodeWhisperer API.
|
||||
@@ -525,3 +1051,324 @@ func (c *SSOOIDCClient) tryListCustomizations(ctx context.Context, accessToken s
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// RegisterClientForAuthCode registers a new OIDC client for authorization code flow.
|
||||
func (c *SSOOIDCClient) RegisterClientForAuthCode(ctx context.Context, redirectURI string) (*RegisterClientResponse, error) {
|
||||
payload := map[string]interface{}{
|
||||
"clientName": "Kiro IDE",
|
||||
"clientType": "public",
|
||||
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
|
||||
"grantTypes": []string{"authorization_code", "refresh_token"},
|
||||
"redirectUris": []string{redirectURI},
|
||||
"issuerUrl": builderIDStartURL,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/client/register", strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("register client for auth code failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result RegisterClientResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// AuthCodeCallbackResult contains the result from authorization code callback.
|
||||
type AuthCodeCallbackResult struct {
|
||||
Code string
|
||||
State string
|
||||
Error string
|
||||
}
|
||||
|
||||
// startAuthCodeCallbackServer starts a local HTTP server to receive the authorization code callback.
|
||||
func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expectedState string) (string, <-chan AuthCodeCallbackResult, error) {
|
||||
// Try to find an available port
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", authCodeCallbackPort))
|
||||
if err != nil {
|
||||
// Try with dynamic port
|
||||
log.Warnf("sso oidc: default port %d is busy, falling back to dynamic port", authCodeCallbackPort)
|
||||
listener, err = net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to start callback server: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
redirectURI := fmt.Sprintf("http://127.0.0.1:%d%s", port, authCodeCallbackPath)
|
||||
resultChan := make(chan AuthCodeCallbackResult, 1)
|
||||
|
||||
server := &http.Server{
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc(authCodeCallbackPath, func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
errParam := r.URL.Query().Get("error")
|
||||
|
||||
// Send response to browser
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if errParam != "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(w, `<!DOCTYPE html>
|
||||
<html><head><title>Login Failed</title></head>
|
||||
<body><h1>Login Failed</h1><p>Error: %s</p><p>You can close this window.</p></body></html>`, html.EscapeString(errParam))
|
||||
resultChan <- AuthCodeCallbackResult{Error: errParam}
|
||||
return
|
||||
}
|
||||
|
||||
if state != expectedState {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(w, `<!DOCTYPE html>
|
||||
<html><head><title>Login Failed</title></head>
|
||||
<body><h1>Login Failed</h1><p>Invalid state parameter</p><p>You can close this window.</p></body></html>`)
|
||||
resultChan <- AuthCodeCallbackResult{Error: "state mismatch"}
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprint(w, `<!DOCTYPE html>
|
||||
<html><head><title>Login Successful</title></head>
|
||||
<body><h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p>
|
||||
<script>window.close();</script></body></html>`)
|
||||
resultChan <- AuthCodeCallbackResult{Code: code, State: state}
|
||||
})
|
||||
|
||||
server.Handler = mux
|
||||
|
||||
go func() {
|
||||
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||
log.Debugf("auth code callback server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(10 * time.Minute):
|
||||
case <-resultChan:
|
||||
}
|
||||
_ = server.Shutdown(context.Background())
|
||||
}()
|
||||
|
||||
return redirectURI, resultChan, nil
|
||||
}
|
||||
|
||||
// generatePKCEForAuthCode generates PKCE code verifier and challenge for authorization code flow.
|
||||
func generatePKCEForAuthCode() (verifier, challenge string, err error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
verifier = base64.RawURLEncoding.EncodeToString(b)
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(h[:])
|
||||
return verifier, challenge, nil
|
||||
}
|
||||
|
||||
// generateStateForAuthCode generates a random state parameter.
|
||||
func generateStateForAuthCode() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// CreateTokenWithAuthCode exchanges authorization code for tokens.
|
||||
func (c *SSOOIDCClient) CreateTokenWithAuthCode(ctx context.Context, clientID, clientSecret, code, codeVerifier, redirectURI string) (*CreateTokenResponse, error) {
|
||||
payload := map[string]string{
|
||||
"clientId": clientID,
|
||||
"clientSecret": clientSecret,
|
||||
"code": code,
|
||||
"codeVerifier": codeVerifier,
|
||||
"redirectUri": redirectURI,
|
||||
"grantType": "authorization_code",
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("create token with auth code failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result CreateTokenResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// LoginWithBuilderIDAuthCode performs the authorization code flow for AWS Builder ID.
|
||||
// This provides a better UX than device code flow as it uses automatic browser callback.
|
||||
func (c *SSOOIDCClient) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTokenData, error) {
|
||||
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
|
||||
fmt.Println("║ Kiro Authentication (AWS Builder ID - Auth Code) ║")
|
||||
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||
|
||||
// Step 1: Generate PKCE and state
|
||||
codeVerifier, codeChallenge, err := generatePKCEForAuthCode()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate PKCE: %w", err)
|
||||
}
|
||||
|
||||
state, err := generateStateForAuthCode()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Start callback server
|
||||
fmt.Println("\nStarting callback server...")
|
||||
redirectURI, resultChan, err := c.startAuthCodeCallbackServer(ctx, state)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start callback server: %w", err)
|
||||
}
|
||||
log.Debugf("Callback server started, redirect URI: %s", redirectURI)
|
||||
|
||||
// Step 3: Register client with auth code grant type
|
||||
fmt.Println("Registering client...")
|
||||
regResp, err := c.RegisterClientForAuthCode(ctx, redirectURI)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to register client: %w", err)
|
||||
}
|
||||
log.Debugf("Client registered: %s", regResp.ClientID)
|
||||
|
||||
// Step 4: Build authorization URL
|
||||
scopes := "codewhisperer:completions,codewhisperer:analysis,codewhisperer:conversations"
|
||||
authURL := fmt.Sprintf("%s/authorize?response_type=code&client_id=%s&redirect_uri=%s&scopes=%s&state=%s&code_challenge=%s&code_challenge_method=S256",
|
||||
ssoOIDCEndpoint,
|
||||
regResp.ClientID,
|
||||
redirectURI,
|
||||
scopes,
|
||||
state,
|
||||
codeChallenge,
|
||||
)
|
||||
|
||||
// Step 5: Open browser
|
||||
fmt.Println("\n════════════════════════════════════════════════════════════")
|
||||
fmt.Println(" Opening browser for authentication...")
|
||||
fmt.Println("════════════════════════════════════════════════════════════")
|
||||
fmt.Printf("\n URL: %s\n\n", authURL)
|
||||
|
||||
// Set incognito mode
|
||||
if c.cfg != nil {
|
||||
browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
|
||||
} else {
|
||||
browser.SetIncognitoMode(true)
|
||||
}
|
||||
|
||||
if err := browser.OpenURL(authURL); err != nil {
|
||||
log.Warnf("Could not open browser automatically: %v", err)
|
||||
fmt.Println(" ⚠ Could not open browser automatically.")
|
||||
fmt.Println(" Please open the URL above in your browser manually.")
|
||||
} else {
|
||||
fmt.Println(" (Browser opened automatically)")
|
||||
}
|
||||
|
||||
fmt.Println("\n Waiting for authorization callback...")
|
||||
|
||||
// Step 6: Wait for callback
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
browser.CloseBrowser()
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(10 * time.Minute):
|
||||
browser.CloseBrowser()
|
||||
return nil, fmt.Errorf("authorization timed out")
|
||||
case result := <-resultChan:
|
||||
if result.Error != "" {
|
||||
browser.CloseBrowser()
|
||||
return nil, fmt.Errorf("authorization failed: %s", result.Error)
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Authorization received!")
|
||||
|
||||
// Close browser
|
||||
if err := browser.CloseBrowser(); err != nil {
|
||||
log.Debugf("Failed to close browser: %v", err)
|
||||
}
|
||||
|
||||
// Step 7: Exchange code for tokens
|
||||
fmt.Println("Exchanging code for tokens...")
|
||||
tokenResp, err := c.CreateTokenWithAuthCode(ctx, regResp.ClientID, regResp.ClientSecret, result.Code, codeVerifier, redirectURI)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Authentication successful!")
|
||||
|
||||
// Step 8: Get profile ARN
|
||||
fmt.Println("Fetching profile information...")
|
||||
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
|
||||
|
||||
// Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing)
|
||||
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
|
||||
if email != "" {
|
||||
fmt.Printf(" Logged in as: %s\n", email)
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||
|
||||
return &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: profileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "builder-id",
|
||||
Provider: "AWS",
|
||||
ClientID: regResp.ClientID,
|
||||
ClientSecret: regResp.ClientSecret,
|
||||
Email: email,
|
||||
Region: defaultIDCRegion,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
|
||||
// KiroTokenStorage holds the persistent token data for Kiro authentication.
|
||||
type KiroTokenStorage struct {
|
||||
// Type is the provider type for management UI recognition (must be "kiro")
|
||||
Type string `json:"type"`
|
||||
// AccessToken is the OAuth2 access token for API access
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is used to obtain new access tokens
|
||||
@@ -23,6 +25,16 @@ type KiroTokenStorage struct {
|
||||
Provider string `json:"provider"`
|
||||
// LastRefresh is the timestamp of the last token refresh
|
||||
LastRefresh string `json:"last_refresh"`
|
||||
// ClientID is the OAuth client ID (required for token refresh)
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
// ClientSecret is the OAuth client secret (required for token refresh)
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
// Region is the AWS region
|
||||
Region string `json:"region,omitempty"`
|
||||
// StartURL is the AWS Identity Center start URL (for IDC auth)
|
||||
StartURL string `json:"start_url,omitempty"`
|
||||
// Email is the user's email address
|
||||
Email string `json:"email,omitempty"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile persists the token storage to the specified file path.
|
||||
@@ -68,5 +80,10 @@ func (s *KiroTokenStorage) ToTokenData() *KiroTokenData {
|
||||
ExpiresAt: s.ExpiresAt,
|
||||
AuthMethod: s.AuthMethod,
|
||||
Provider: s.Provider,
|
||||
ClientID: s.ClientID,
|
||||
ClientSecret: s.ClientSecret,
|
||||
Region: s.Region,
|
||||
StartURL: s.StartURL,
|
||||
Email: s.Email,
|
||||
}
|
||||
}
|
||||
|
||||
273
internal/auth/kiro/token_repository.go
Normal file
273
internal/auth/kiro/token_repository.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// FileTokenRepository 实现 TokenRepository 接口,基于文件系统存储
|
||||
type FileTokenRepository struct {
|
||||
mu sync.RWMutex
|
||||
baseDir string
|
||||
}
|
||||
|
||||
// NewFileTokenRepository 创建一个新的文件 token 存储库
|
||||
func NewFileTokenRepository(baseDir string) *FileTokenRepository {
|
||||
return &FileTokenRepository{
|
||||
baseDir: baseDir,
|
||||
}
|
||||
}
|
||||
|
||||
// SetBaseDir 设置基础目录
|
||||
func (r *FileTokenRepository) SetBaseDir(dir string) {
|
||||
r.mu.Lock()
|
||||
r.baseDir = strings.TrimSpace(dir)
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
// FindOldestUnverified 查找需要刷新的 token(按最后验证时间排序)
|
||||
func (r *FileTokenRepository) FindOldestUnverified(limit int) []*Token {
|
||||
r.mu.RLock()
|
||||
baseDir := r.baseDir
|
||||
r.mu.RUnlock()
|
||||
|
||||
if baseDir == "" {
|
||||
log.Debug("token repository: base directory not configured")
|
||||
return nil
|
||||
}
|
||||
|
||||
var tokens []*Token
|
||||
|
||||
err := filepath.WalkDir(baseDir, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return nil // 忽略错误,继续遍历
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if !strings.HasSuffix(strings.ToLower(d.Name()), ".json") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 只处理 kiro 相关的 token 文件
|
||||
if !strings.HasPrefix(d.Name(), "kiro-") {
|
||||
return nil
|
||||
}
|
||||
|
||||
token, err := r.readTokenFile(path)
|
||||
if err != nil {
|
||||
log.Debugf("token repository: failed to read token file %s: %v", path, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if token != nil && token.RefreshToken != "" {
|
||||
// 检查 token 是否需要刷新(过期前 5 分钟)
|
||||
if token.ExpiresAt.IsZero() || time.Until(token.ExpiresAt) < 5*time.Minute {
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Warnf("token repository: error walking directory: %v", err)
|
||||
}
|
||||
|
||||
// 按最后验证时间排序(最旧的优先)
|
||||
sort.Slice(tokens, func(i, j int) bool {
|
||||
return tokens[i].LastVerified.Before(tokens[j].LastVerified)
|
||||
})
|
||||
|
||||
// 限制返回数量
|
||||
if limit > 0 && len(tokens) > limit {
|
||||
tokens = tokens[:limit]
|
||||
}
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
// UpdateToken 更新 token 并持久化到文件
|
||||
func (r *FileTokenRepository) UpdateToken(token *Token) error {
|
||||
if token == nil {
|
||||
return fmt.Errorf("token repository: token is nil")
|
||||
}
|
||||
|
||||
r.mu.RLock()
|
||||
baseDir := r.baseDir
|
||||
r.mu.RUnlock()
|
||||
|
||||
if baseDir == "" {
|
||||
return fmt.Errorf("token repository: base directory not configured")
|
||||
}
|
||||
|
||||
// 构建文件路径
|
||||
filePath := filepath.Join(baseDir, token.ID)
|
||||
if !strings.HasSuffix(filePath, ".json") {
|
||||
filePath += ".json"
|
||||
}
|
||||
|
||||
// 读取现有文件内容
|
||||
existingData := make(map[string]any)
|
||||
if data, err := os.ReadFile(filePath); err == nil {
|
||||
_ = json.Unmarshal(data, &existingData)
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
existingData["access_token"] = token.AccessToken
|
||||
existingData["refresh_token"] = token.RefreshToken
|
||||
existingData["last_refresh"] = time.Now().Format(time.RFC3339)
|
||||
|
||||
if !token.ExpiresAt.IsZero() {
|
||||
existingData["expires_at"] = token.ExpiresAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// 保持原有的关键字段
|
||||
if token.ClientID != "" {
|
||||
existingData["client_id"] = token.ClientID
|
||||
}
|
||||
if token.ClientSecret != "" {
|
||||
existingData["client_secret"] = token.ClientSecret
|
||||
}
|
||||
if token.AuthMethod != "" {
|
||||
existingData["auth_method"] = token.AuthMethod
|
||||
}
|
||||
if token.Region != "" {
|
||||
existingData["region"] = token.Region
|
||||
}
|
||||
if token.StartURL != "" {
|
||||
existingData["start_url"] = token.StartURL
|
||||
}
|
||||
|
||||
// 序列化并写入文件
|
||||
raw, err := json.MarshalIndent(existingData, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("token repository: marshal failed: %w", err)
|
||||
}
|
||||
|
||||
// 原子写入:先写入临时文件,再重命名
|
||||
tmpPath := filePath + ".tmp"
|
||||
if err := os.WriteFile(tmpPath, raw, 0o600); err != nil {
|
||||
return fmt.Errorf("token repository: write temp file failed: %w", err)
|
||||
}
|
||||
if err := os.Rename(tmpPath, filePath); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("token repository: rename failed: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("token repository: updated token %s", token.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// readTokenFile 从文件读取 token
|
||||
func (r *FileTokenRepository) readTokenFile(path string) (*Token, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var metadata map[string]any
|
||||
if err := json.Unmarshal(data, &metadata); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查是否是 kiro token
|
||||
tokenType, _ := metadata["type"].(string)
|
||||
if tokenType != "kiro" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 检查 auth_method
|
||||
authMethod, _ := metadata["auth_method"].(string)
|
||||
if authMethod != "idc" && authMethod != "builder-id" {
|
||||
return nil, nil // 只处理 IDC 和 Builder ID token
|
||||
}
|
||||
|
||||
token := &Token{
|
||||
ID: filepath.Base(path),
|
||||
AuthMethod: authMethod,
|
||||
}
|
||||
|
||||
// 解析各字段
|
||||
if v, ok := metadata["access_token"].(string); ok {
|
||||
token.AccessToken = v
|
||||
}
|
||||
if v, ok := metadata["refresh_token"].(string); ok {
|
||||
token.RefreshToken = v
|
||||
}
|
||||
if v, ok := metadata["client_id"].(string); ok {
|
||||
token.ClientID = v
|
||||
}
|
||||
if v, ok := metadata["client_secret"].(string); ok {
|
||||
token.ClientSecret = v
|
||||
}
|
||||
if v, ok := metadata["region"].(string); ok {
|
||||
token.Region = v
|
||||
}
|
||||
if v, ok := metadata["start_url"].(string); ok {
|
||||
token.StartURL = v
|
||||
}
|
||||
if v, ok := metadata["provider"].(string); ok {
|
||||
token.Provider = v
|
||||
}
|
||||
|
||||
// 解析时间字段
|
||||
if v, ok := metadata["expires_at"].(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339, v); err == nil {
|
||||
token.ExpiresAt = t
|
||||
}
|
||||
}
|
||||
if v, ok := metadata["last_refresh"].(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339, v); err == nil {
|
||||
token.LastVerified = t
|
||||
}
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// ListKiroTokens 列出所有 Kiro token(用于调试)
|
||||
func (r *FileTokenRepository) ListKiroTokens(ctx context.Context) ([]*Token, error) {
|
||||
r.mu.RLock()
|
||||
baseDir := r.baseDir
|
||||
r.mu.RUnlock()
|
||||
|
||||
if baseDir == "" {
|
||||
return nil, fmt.Errorf("token repository: base directory not configured")
|
||||
}
|
||||
|
||||
var tokens []*Token
|
||||
|
||||
err := filepath.WalkDir(baseDir, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return nil
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if !strings.HasPrefix(d.Name(), "kiro-") || !strings.HasSuffix(d.Name(), ".json") {
|
||||
return nil
|
||||
}
|
||||
|
||||
token, err := r.readTokenFile(path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if token != nil {
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return tokens, err
|
||||
}
|
||||
243
internal/auth/kiro/usage_checker.go
Normal file
243
internal/auth/kiro/usage_checker.go
Normal file
@@ -0,0 +1,243 @@
|
||||
// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API.
|
||||
// This file implements usage quota checking and monitoring.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
// UsageQuotaResponse represents the API response structure for usage quota checking.
|
||||
type UsageQuotaResponse struct {
|
||||
UsageBreakdownList []UsageBreakdownExtended `json:"usageBreakdownList"`
|
||||
SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
|
||||
NextDateReset float64 `json:"nextDateReset,omitempty"`
|
||||
}
|
||||
|
||||
// UsageBreakdownExtended represents detailed usage information for quota checking.
|
||||
// Note: UsageBreakdown is already defined in codewhisperer_client.go
|
||||
type UsageBreakdownExtended struct {
|
||||
ResourceType string `json:"resourceType"`
|
||||
UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
|
||||
CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
|
||||
FreeTrialInfo *FreeTrialInfoExtended `json:"freeTrialInfo,omitempty"`
|
||||
}
|
||||
|
||||
// FreeTrialInfoExtended represents free trial usage information.
|
||||
type FreeTrialInfoExtended struct {
|
||||
FreeTrialStatus string `json:"freeTrialStatus"`
|
||||
UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
|
||||
CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
|
||||
}
|
||||
|
||||
// QuotaStatus represents the quota status for a token.
|
||||
type QuotaStatus struct {
|
||||
TotalLimit float64
|
||||
CurrentUsage float64
|
||||
RemainingQuota float64
|
||||
IsExhausted bool
|
||||
ResourceType string
|
||||
NextReset time.Time
|
||||
}
|
||||
|
||||
// UsageChecker provides methods for checking token quota usage.
|
||||
type UsageChecker struct {
|
||||
httpClient *http.Client
|
||||
endpoint string
|
||||
}
|
||||
|
||||
// NewUsageChecker creates a new UsageChecker instance.
|
||||
func NewUsageChecker(cfg *config.Config) *UsageChecker {
|
||||
return &UsageChecker{
|
||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}),
|
||||
endpoint: awsKiroEndpoint,
|
||||
}
|
||||
}
|
||||
|
||||
// NewUsageCheckerWithClient creates a UsageChecker with a custom HTTP client.
|
||||
func NewUsageCheckerWithClient(client *http.Client) *UsageChecker {
|
||||
return &UsageChecker{
|
||||
httpClient: client,
|
||||
endpoint: awsKiroEndpoint,
|
||||
}
|
||||
}
|
||||
|
||||
// CheckUsage retrieves usage limits for the given token.
|
||||
func (c *UsageChecker) CheckUsage(ctx context.Context, tokenData *KiroTokenData) (*UsageQuotaResponse, error) {
|
||||
if tokenData == nil {
|
||||
return nil, fmt.Errorf("token data is nil")
|
||||
}
|
||||
|
||||
if tokenData.AccessToken == "" {
|
||||
return nil, fmt.Errorf("access token is empty")
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"origin": "AI_EDITOR",
|
||||
"profileArn": tokenData.ProfileArn,
|
||||
"resourceType": "AGENTIC_REQUEST",
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, strings.NewReader(string(jsonBody)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
||||
req.Header.Set("x-amz-target", targetGetUsage)
|
||||
req.Header.Set("Authorization", "Bearer "+tokenData.AccessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result UsageQuotaResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse usage response: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// CheckUsageByAccessToken retrieves usage limits using an access token and profile ARN directly.
|
||||
func (c *UsageChecker) CheckUsageByAccessToken(ctx context.Context, accessToken, profileArn string) (*UsageQuotaResponse, error) {
|
||||
tokenData := &KiroTokenData{
|
||||
AccessToken: accessToken,
|
||||
ProfileArn: profileArn,
|
||||
}
|
||||
return c.CheckUsage(ctx, tokenData)
|
||||
}
|
||||
|
||||
// GetRemainingQuota calculates the remaining quota from usage limits.
|
||||
func GetRemainingQuota(usage *UsageQuotaResponse) float64 {
|
||||
if usage == nil || len(usage.UsageBreakdownList) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
var totalRemaining float64
|
||||
for _, breakdown := range usage.UsageBreakdownList {
|
||||
remaining := breakdown.UsageLimitWithPrecision - breakdown.CurrentUsageWithPrecision
|
||||
if remaining > 0 {
|
||||
totalRemaining += remaining
|
||||
}
|
||||
|
||||
if breakdown.FreeTrialInfo != nil {
|
||||
freeRemaining := breakdown.FreeTrialInfo.UsageLimitWithPrecision - breakdown.FreeTrialInfo.CurrentUsageWithPrecision
|
||||
if freeRemaining > 0 {
|
||||
totalRemaining += freeRemaining
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return totalRemaining
|
||||
}
|
||||
|
||||
// IsQuotaExhausted checks if the quota is exhausted based on usage limits.
|
||||
func IsQuotaExhausted(usage *UsageQuotaResponse) bool {
|
||||
if usage == nil || len(usage.UsageBreakdownList) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, breakdown := range usage.UsageBreakdownList {
|
||||
if breakdown.CurrentUsageWithPrecision < breakdown.UsageLimitWithPrecision {
|
||||
return false
|
||||
}
|
||||
|
||||
if breakdown.FreeTrialInfo != nil {
|
||||
if breakdown.FreeTrialInfo.CurrentUsageWithPrecision < breakdown.FreeTrialInfo.UsageLimitWithPrecision {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// GetQuotaStatus retrieves a comprehensive quota status for a token.
|
||||
func (c *UsageChecker) GetQuotaStatus(ctx context.Context, tokenData *KiroTokenData) (*QuotaStatus, error) {
|
||||
usage, err := c.CheckUsage(ctx, tokenData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
status := &QuotaStatus{
|
||||
IsExhausted: IsQuotaExhausted(usage),
|
||||
}
|
||||
|
||||
if len(usage.UsageBreakdownList) > 0 {
|
||||
breakdown := usage.UsageBreakdownList[0]
|
||||
status.TotalLimit = breakdown.UsageLimitWithPrecision
|
||||
status.CurrentUsage = breakdown.CurrentUsageWithPrecision
|
||||
status.RemainingQuota = breakdown.UsageLimitWithPrecision - breakdown.CurrentUsageWithPrecision
|
||||
status.ResourceType = breakdown.ResourceType
|
||||
|
||||
if breakdown.FreeTrialInfo != nil {
|
||||
status.TotalLimit += breakdown.FreeTrialInfo.UsageLimitWithPrecision
|
||||
status.CurrentUsage += breakdown.FreeTrialInfo.CurrentUsageWithPrecision
|
||||
freeRemaining := breakdown.FreeTrialInfo.UsageLimitWithPrecision - breakdown.FreeTrialInfo.CurrentUsageWithPrecision
|
||||
if freeRemaining > 0 {
|
||||
status.RemainingQuota += freeRemaining
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if usage.NextDateReset > 0 {
|
||||
status.NextReset = time.Unix(int64(usage.NextDateReset/1000), 0)
|
||||
}
|
||||
|
||||
return status, nil
|
||||
}
|
||||
|
||||
// CalculateAvailableCount calculates the available request count based on usage limits.
|
||||
func CalculateAvailableCount(usage *UsageQuotaResponse) float64 {
|
||||
return GetRemainingQuota(usage)
|
||||
}
|
||||
|
||||
// GetUsagePercentage calculates the usage percentage.
|
||||
func GetUsagePercentage(usage *UsageQuotaResponse) float64 {
|
||||
if usage == nil || len(usage.UsageBreakdownList) == 0 {
|
||||
return 100.0
|
||||
}
|
||||
|
||||
var totalLimit, totalUsage float64
|
||||
for _, breakdown := range usage.UsageBreakdownList {
|
||||
totalLimit += breakdown.UsageLimitWithPrecision
|
||||
totalUsage += breakdown.CurrentUsageWithPrecision
|
||||
|
||||
if breakdown.FreeTrialInfo != nil {
|
||||
totalLimit += breakdown.FreeTrialInfo.UsageLimitWithPrecision
|
||||
totalUsage += breakdown.FreeTrialInfo.CurrentUsageWithPrecision
|
||||
}
|
||||
}
|
||||
|
||||
if totalLimit == 0 {
|
||||
return 100.0
|
||||
}
|
||||
|
||||
return (totalUsage / totalLimit) * 100
|
||||
}
|
||||
195
internal/cache/signature_cache.go
vendored
Normal file
195
internal/cache/signature_cache.go
vendored
Normal file
@@ -0,0 +1,195 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SignatureEntry holds a cached thinking signature with timestamp
|
||||
type SignatureEntry struct {
|
||||
Signature string
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
// SignatureCacheTTL is how long signatures are valid
|
||||
SignatureCacheTTL = 3 * time.Hour
|
||||
|
||||
// SignatureTextHashLen is the length of the hash key (16 hex chars = 64-bit key space)
|
||||
SignatureTextHashLen = 16
|
||||
|
||||
// MinValidSignatureLen is the minimum length for a signature to be considered valid
|
||||
MinValidSignatureLen = 50
|
||||
|
||||
// CacheCleanupInterval controls how often stale entries are purged
|
||||
CacheCleanupInterval = 10 * time.Minute
|
||||
)
|
||||
|
||||
// signatureCache stores signatures by model group -> textHash -> SignatureEntry
|
||||
var signatureCache sync.Map
|
||||
|
||||
// cacheCleanupOnce ensures the background cleanup goroutine starts only once
|
||||
var cacheCleanupOnce sync.Once
|
||||
|
||||
// groupCache is the inner map type
|
||||
type groupCache struct {
|
||||
mu sync.RWMutex
|
||||
entries map[string]SignatureEntry
|
||||
}
|
||||
|
||||
// hashText creates a stable, Unicode-safe key from text content
|
||||
func hashText(text string) string {
|
||||
h := sha256.Sum256([]byte(text))
|
||||
return hex.EncodeToString(h[:])[:SignatureTextHashLen]
|
||||
}
|
||||
|
||||
// getOrCreateGroupCache gets or creates a cache bucket for a model group
|
||||
func getOrCreateGroupCache(groupKey string) *groupCache {
|
||||
// Start background cleanup on first access
|
||||
cacheCleanupOnce.Do(startCacheCleanup)
|
||||
|
||||
if val, ok := signatureCache.Load(groupKey); ok {
|
||||
return val.(*groupCache)
|
||||
}
|
||||
sc := &groupCache{entries: make(map[string]SignatureEntry)}
|
||||
actual, _ := signatureCache.LoadOrStore(groupKey, sc)
|
||||
return actual.(*groupCache)
|
||||
}
|
||||
|
||||
// startCacheCleanup launches a background goroutine that periodically
|
||||
// removes caches where all entries have expired.
|
||||
func startCacheCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(CacheCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
purgeExpiredCaches()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// purgeExpiredCaches removes caches with no valid (non-expired) entries.
|
||||
func purgeExpiredCaches() {
|
||||
now := time.Now()
|
||||
signatureCache.Range(func(key, value any) bool {
|
||||
sc := value.(*groupCache)
|
||||
sc.mu.Lock()
|
||||
// Remove expired entries
|
||||
for k, entry := range sc.entries {
|
||||
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
|
||||
delete(sc.entries, k)
|
||||
}
|
||||
}
|
||||
isEmpty := len(sc.entries) == 0
|
||||
sc.mu.Unlock()
|
||||
// Remove cache bucket if empty
|
||||
if isEmpty {
|
||||
signatureCache.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// CacheSignature stores a thinking signature for a given model group and text.
|
||||
// Used for Claude models that require signed thinking blocks in multi-turn conversations.
|
||||
func CacheSignature(modelName, text, signature string) {
|
||||
if text == "" || signature == "" {
|
||||
return
|
||||
}
|
||||
if len(signature) < MinValidSignatureLen {
|
||||
return
|
||||
}
|
||||
|
||||
groupKey := GetModelGroup(modelName)
|
||||
textHash := hashText(text)
|
||||
sc := getOrCreateGroupCache(groupKey)
|
||||
sc.mu.Lock()
|
||||
defer sc.mu.Unlock()
|
||||
|
||||
sc.entries[textHash] = SignatureEntry{
|
||||
Signature: signature,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// GetCachedSignature retrieves a cached signature for a given model group and text.
|
||||
// Returns empty string if not found or expired.
|
||||
func GetCachedSignature(modelName, text string) string {
|
||||
groupKey := GetModelGroup(modelName)
|
||||
|
||||
if text == "" {
|
||||
if groupKey == "gemini" {
|
||||
return "skip_thought_signature_validator"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
val, ok := signatureCache.Load(groupKey)
|
||||
if !ok {
|
||||
if groupKey == "gemini" {
|
||||
return "skip_thought_signature_validator"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
sc := val.(*groupCache)
|
||||
|
||||
textHash := hashText(text)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
sc.mu.Lock()
|
||||
entry, exists := sc.entries[textHash]
|
||||
if !exists {
|
||||
sc.mu.Unlock()
|
||||
if groupKey == "gemini" {
|
||||
return "skip_thought_signature_validator"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
|
||||
delete(sc.entries, textHash)
|
||||
sc.mu.Unlock()
|
||||
if groupKey == "gemini" {
|
||||
return "skip_thought_signature_validator"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Refresh TTL on access (sliding expiration).
|
||||
entry.Timestamp = now
|
||||
sc.entries[textHash] = entry
|
||||
sc.mu.Unlock()
|
||||
|
||||
return entry.Signature
|
||||
}
|
||||
|
||||
// ClearSignatureCache clears signature cache for a specific model group or all groups.
|
||||
func ClearSignatureCache(modelName string) {
|
||||
if modelName == "" {
|
||||
signatureCache.Range(func(key, _ any) bool {
|
||||
signatureCache.Delete(key)
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
groupKey := GetModelGroup(modelName)
|
||||
signatureCache.Delete(groupKey)
|
||||
}
|
||||
|
||||
// HasValidSignature checks if a signature is valid (non-empty and long enough)
|
||||
func HasValidSignature(modelName, signature string) bool {
|
||||
return (signature != "" && len(signature) >= MinValidSignatureLen) || (signature == "skip_thought_signature_validator" && GetModelGroup(modelName) == "gemini")
|
||||
}
|
||||
|
||||
func GetModelGroup(modelName string) string {
|
||||
if strings.Contains(modelName, "gpt") {
|
||||
return "gpt"
|
||||
} else if strings.Contains(modelName, "claude") {
|
||||
return "claude"
|
||||
} else if strings.Contains(modelName, "gemini") {
|
||||
return "gemini"
|
||||
}
|
||||
return modelName
|
||||
}
|
||||
210
internal/cache/signature_cache_test.go
vendored
Normal file
210
internal/cache/signature_cache_test.go
vendored
Normal file
@@ -0,0 +1,210 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const testModelName = "claude-sonnet-4-5"
|
||||
|
||||
func TestCacheSignature_BasicStorageAndRetrieval(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
text := "This is some thinking text content"
|
||||
signature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
|
||||
// Store signature
|
||||
CacheSignature(testModelName, text, signature)
|
||||
|
||||
// Retrieve signature
|
||||
retrieved := GetCachedSignature(testModelName, text)
|
||||
if retrieved != signature {
|
||||
t.Errorf("Expected signature '%s', got '%s'", signature, retrieved)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_DifferentModelGroups(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
text := "Same text across models"
|
||||
sig1 := "signature1_1234567890123456789012345678901234567890123456"
|
||||
sig2 := "signature2_1234567890123456789012345678901234567890123456"
|
||||
|
||||
geminiModel := "gemini-3-pro-preview"
|
||||
CacheSignature(testModelName, text, sig1)
|
||||
CacheSignature(geminiModel, text, sig2)
|
||||
|
||||
if GetCachedSignature(testModelName, text) != sig1 {
|
||||
t.Error("Claude signature mismatch")
|
||||
}
|
||||
if GetCachedSignature(geminiModel, text) != sig2 {
|
||||
t.Error("Gemini signature mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_NotFound(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
// Non-existent session
|
||||
if got := GetCachedSignature(testModelName, "some text"); got != "" {
|
||||
t.Errorf("Expected empty string for nonexistent session, got '%s'", got)
|
||||
}
|
||||
|
||||
// Existing session but different text
|
||||
CacheSignature(testModelName, "text-a", "sigA12345678901234567890123456789012345678901234567890")
|
||||
if got := GetCachedSignature(testModelName, "text-b"); got != "" {
|
||||
t.Errorf("Expected empty string for different text, got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_EmptyInputs(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
// All empty/invalid inputs should be no-ops
|
||||
CacheSignature(testModelName, "", "sig12345678901234567890123456789012345678901234567890")
|
||||
CacheSignature(testModelName, "text", "")
|
||||
CacheSignature(testModelName, "text", "short") // Too short
|
||||
|
||||
if got := GetCachedSignature(testModelName, "text"); got != "" {
|
||||
t.Errorf("Expected empty after invalid cache attempts, got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_ShortSignatureRejected(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
text := "Some text"
|
||||
shortSig := "abc123" // Less than 50 chars
|
||||
|
||||
CacheSignature(testModelName, text, shortSig)
|
||||
|
||||
if got := GetCachedSignature(testModelName, text); got != "" {
|
||||
t.Errorf("Short signature should be rejected, got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearSignatureCache_ModelGroup(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||
CacheSignature(testModelName, "text", sig)
|
||||
CacheSignature(testModelName, "text-2", sig)
|
||||
|
||||
ClearSignatureCache("session-1")
|
||||
|
||||
if got := GetCachedSignature(testModelName, "text"); got != sig {
|
||||
t.Error("signature should remain when clearing unknown session")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearSignatureCache_AllSessions(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||
CacheSignature(testModelName, "text", sig)
|
||||
CacheSignature(testModelName, "text-2", sig)
|
||||
|
||||
ClearSignatureCache("")
|
||||
|
||||
if got := GetCachedSignature(testModelName, "text"); got != "" {
|
||||
t.Error("text should be cleared")
|
||||
}
|
||||
if got := GetCachedSignature(testModelName, "text-2"); got != "" {
|
||||
t.Error("text-2 should be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasValidSignature(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelName string
|
||||
signature string
|
||||
expected bool
|
||||
}{
|
||||
{"valid long signature", testModelName, "abc123validSignature1234567890123456789012345678901234567890", true},
|
||||
{"exactly 50 chars", testModelName, "12345678901234567890123456789012345678901234567890", true},
|
||||
{"49 chars - invalid", testModelName, "1234567890123456789012345678901234567890123456789", false},
|
||||
{"empty string", testModelName, "", false},
|
||||
{"short signature", testModelName, "abc", false},
|
||||
{"gemini sentinel", "gemini-3-pro-preview", "skip_thought_signature_validator", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := HasValidSignature(tt.modelName, tt.signature)
|
||||
if result != tt.expected {
|
||||
t.Errorf("HasValidSignature(%q) = %v, expected %v", tt.signature, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
// Different texts should produce different hashes
|
||||
text1 := "First thinking text"
|
||||
text2 := "Second thinking text"
|
||||
sig1 := "signature1_1234567890123456789012345678901234567890123456"
|
||||
sig2 := "signature2_1234567890123456789012345678901234567890123456"
|
||||
|
||||
CacheSignature(testModelName, text1, sig1)
|
||||
CacheSignature(testModelName, text2, sig2)
|
||||
|
||||
if GetCachedSignature(testModelName, text1) != sig1 {
|
||||
t.Error("text1 signature mismatch")
|
||||
}
|
||||
if GetCachedSignature(testModelName, text2) != sig2 {
|
||||
t.Error("text2 signature mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_UnicodeText(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
text := "한글 텍스트와 이모지 🎉 그리고 特殊文字"
|
||||
sig := "unicodeSig123456789012345678901234567890123456789012345"
|
||||
|
||||
CacheSignature(testModelName, text, sig)
|
||||
|
||||
if got := GetCachedSignature(testModelName, text); got != sig {
|
||||
t.Errorf("Unicode text signature retrieval failed, got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_Overwrite(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
text := "Same text"
|
||||
sig1 := "firstSignature12345678901234567890123456789012345678901"
|
||||
sig2 := "secondSignature1234567890123456789012345678901234567890"
|
||||
|
||||
CacheSignature(testModelName, text, sig1)
|
||||
CacheSignature(testModelName, text, sig2) // Overwrite
|
||||
|
||||
if got := GetCachedSignature(testModelName, text); got != sig2 {
|
||||
t.Errorf("Expected overwritten signature '%s', got '%s'", sig2, got)
|
||||
}
|
||||
}
|
||||
|
||||
// Note: TTL expiration test is tricky to test without mocking time
|
||||
// We test the logic path exists but actual expiration would require time manipulation
|
||||
func TestCacheSignature_ExpirationLogic(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
// This test verifies the expiration check exists
|
||||
// In a real scenario, we'd mock time.Now()
|
||||
text := "text"
|
||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||
|
||||
CacheSignature(testModelName, text, sig)
|
||||
|
||||
// Fresh entry should be retrievable
|
||||
if got := GetCachedSignature(testModelName, text); got != sig {
|
||||
t.Errorf("Fresh entry should be retrievable, got '%s'", got)
|
||||
}
|
||||
|
||||
// We can't easily test actual expiration without time mocking
|
||||
// but the logic is verified by the implementation
|
||||
_ = time.Now() // Acknowledge we're not testing time passage
|
||||
}
|
||||
@@ -24,12 +24,18 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)
|
||||
|
||||
@@ -15,11 +15,17 @@ func DoAntigravityLogin(cfg *config.Config, options *LoginOptions) {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts)
|
||||
|
||||
@@ -20,19 +20,14 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = func(prompt string) (string, error) {
|
||||
fmt.Println()
|
||||
fmt.Println(prompt)
|
||||
var value string
|
||||
_, err := fmt.Scanln(&value)
|
||||
return value, err
|
||||
}
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "iflow", cfg, authOpts)
|
||||
|
||||
@@ -116,6 +116,54 @@ func DoKiroAWSLogin(cfg *config.Config, options *LoginOptions) {
|
||||
fmt.Println("Kiro AWS authentication successful!")
|
||||
}
|
||||
|
||||
// DoKiroAWSAuthCodeLogin triggers Kiro authentication with AWS Builder ID using authorization code flow.
|
||||
// This provides a better UX than device code flow as it uses automatic browser callback.
|
||||
//
|
||||
// Parameters:
|
||||
// - cfg: The application configuration
|
||||
// - options: Login options including prompts
|
||||
func DoKiroAWSAuthCodeLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
// Note: Kiro defaults to incognito mode for multi-account support.
|
||||
// Users can override with --no-incognito if they want to use existing browser sessions.
|
||||
|
||||
manager := newAuthManager()
|
||||
|
||||
// Use KiroAuthenticator with AWS Builder ID login (authorization code flow)
|
||||
authenticator := sdkAuth.NewKiroAuthenticator()
|
||||
record, err := authenticator.LoginWithAuthCode(context.Background(), cfg, &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("Kiro AWS authentication (auth code) failed: %v", err)
|
||||
fmt.Println("\nTroubleshooting:")
|
||||
fmt.Println("1. Make sure you have an AWS Builder ID")
|
||||
fmt.Println("2. Complete the authorization in the browser")
|
||||
fmt.Println("3. If callback fails, try: --kiro-aws-login (device code flow)")
|
||||
return
|
||||
}
|
||||
|
||||
// Save the auth record
|
||||
savedPath, err := manager.SaveAuth(record, cfg)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to save auth: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
if record != nil && record.Label != "" {
|
||||
fmt.Printf("Authenticated as %s\n", record.Label)
|
||||
}
|
||||
fmt.Println("Kiro AWS authentication successful!")
|
||||
}
|
||||
|
||||
// DoKiroImport imports Kiro token from Kiro IDE's token file.
|
||||
// This is useful for users who have already logged in via Kiro IDE
|
||||
// and want to use the same credentials in CLI Proxy API.
|
||||
|
||||
@@ -55,11 +55,23 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
trimmedProjectID := strings.TrimSpace(projectID)
|
||||
callbackPrompt := promptFn
|
||||
if trimmedProjectID == "" {
|
||||
callbackPrompt = nil
|
||||
}
|
||||
|
||||
loginOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
ProjectID: strings.TrimSpace(projectID),
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
NoBrowser: options.NoBrowser,
|
||||
ProjectID: trimmedProjectID,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: callbackPrompt,
|
||||
}
|
||||
|
||||
authenticator := sdkAuth.NewGeminiAuthenticator()
|
||||
@@ -76,7 +88,11 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
}
|
||||
|
||||
geminiAuth := gemini.NewGeminiAuth()
|
||||
httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, options.NoBrowser)
|
||||
httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Prompt: callbackPrompt,
|
||||
})
|
||||
if errClient != nil {
|
||||
log.Errorf("Gemini authentication failed: %v", errClient)
|
||||
return
|
||||
@@ -90,12 +106,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
return
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
selectedProjectID := promptForProjectSelection(projects, strings.TrimSpace(projectID), promptFn)
|
||||
selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn)
|
||||
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
|
||||
if errSelection != nil {
|
||||
log.Errorf("Invalid project selection: %v", errSelection)
|
||||
@@ -107,6 +118,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
}
|
||||
|
||||
activatedProjects := make([]string, 0, len(projectSelections))
|
||||
seenProjects := make(map[string]bool)
|
||||
for _, candidateID := range projectSelections {
|
||||
log.Infof("Activating project %s", candidateID)
|
||||
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
|
||||
@@ -123,6 +135,13 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
if finalID == "" {
|
||||
finalID = candidateID
|
||||
}
|
||||
|
||||
// Skip duplicates
|
||||
if seenProjects[finalID] {
|
||||
log.Infof("Project %s already activated, skipping", finalID)
|
||||
continue
|
||||
}
|
||||
seenProjects[finalID] = true
|
||||
activatedProjects = append(activatedProjects, finalID)
|
||||
}
|
||||
|
||||
@@ -250,7 +269,39 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
|
||||
finalProjectID := projectID
|
||||
if responseProjectID != "" {
|
||||
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
|
||||
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
|
||||
// Check if this is a free user (gen-lang-client projects or free/legacy tier)
|
||||
isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") ||
|
||||
strings.EqualFold(tierID, "FREE") ||
|
||||
strings.EqualFold(tierID, "LEGACY")
|
||||
|
||||
if isFreeUser {
|
||||
// Interactive prompt for free users
|
||||
fmt.Printf("\nGoogle returned a different project ID:\n")
|
||||
fmt.Printf(" Requested (frontend): %s\n", projectID)
|
||||
fmt.Printf(" Returned (backend): %s\n\n", responseProjectID)
|
||||
fmt.Printf(" Backend project IDs have access to preview models (gemini-3-*).\n")
|
||||
fmt.Printf(" This is normal for free tier users.\n\n")
|
||||
fmt.Printf("Which project ID would you like to use?\n")
|
||||
fmt.Printf(" [1] Backend (recommended): %s\n", responseProjectID)
|
||||
fmt.Printf(" [2] Frontend: %s\n\n", projectID)
|
||||
fmt.Printf("Enter choice [1]: ")
|
||||
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
choice, _ := reader.ReadString('\n')
|
||||
choice = strings.TrimSpace(choice)
|
||||
|
||||
if choice == "2" {
|
||||
log.Infof("Using frontend project ID: %s", projectID)
|
||||
fmt.Println(". Warning: Frontend project IDs may not have access to preview models.")
|
||||
finalProjectID = projectID
|
||||
} else {
|
||||
log.Infof("Using backend project ID: %s (recommended)", responseProjectID)
|
||||
finalProjectID = responseProjectID
|
||||
}
|
||||
} else {
|
||||
// Pro users: keep requested project ID (original behavior)
|
||||
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
|
||||
}
|
||||
} else {
|
||||
finalProjectID = responseProjectID
|
||||
}
|
||||
|
||||
@@ -19,6 +19,9 @@ type LoginOptions struct {
|
||||
// NoBrowser indicates whether to skip opening the browser automatically.
|
||||
NoBrowser bool
|
||||
|
||||
// CallbackPort overrides the local OAuth callback port when set (>0).
|
||||
CallbackPort int
|
||||
|
||||
// Prompt allows the caller to provide interactive input when needed.
|
||||
Prompt func(prompt string) (string, error)
|
||||
}
|
||||
@@ -35,12 +38,18 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
||||
|
||||
@@ -36,9 +36,10 @@ func DoQwenLogin(cfg *config.Config, options *LoginOptions) {
|
||||
}
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts)
|
||||
|
||||
@@ -6,20 +6,23 @@ package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const DefaultPanelGitHubRepository = "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
|
||||
|
||||
// Config represents the application's configuration, loaded from a YAML file.
|
||||
type Config struct {
|
||||
config.SDKConfig `yaml:",inline"`
|
||||
SDKConfig `yaml:",inline"`
|
||||
// Host is the network host/interface on which the API server will bind.
|
||||
// Default is empty ("") to bind all interfaces (IPv4 + IPv6). Use "127.0.0.1" or "localhost" for local-only access.
|
||||
Host string `yaml:"host" json:"-"`
|
||||
@@ -38,9 +41,16 @@ type Config struct {
|
||||
// Debug enables or disables debug-level logging and other debug features.
|
||||
Debug bool `yaml:"debug" json:"debug"`
|
||||
|
||||
// CommercialMode disables high-overhead HTTP middleware features to minimize per-request memory usage.
|
||||
CommercialMode bool `yaml:"commercial-mode" json:"commercial-mode"`
|
||||
|
||||
// LoggingToFile controls whether application logs are written to rotating files or stdout.
|
||||
LoggingToFile bool `yaml:"logging-to-file" json:"logging-to-file"`
|
||||
|
||||
// LogsMaxTotalSizeMB limits the total size (in MB) of log files under the logs directory.
|
||||
// When exceeded, the oldest log files are deleted until within the limit. Set to 0 to disable.
|
||||
LogsMaxTotalSizeMB int `yaml:"logs-max-total-size-mb" json:"logs-max-total-size-mb"`
|
||||
|
||||
// UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded.
|
||||
UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"`
|
||||
|
||||
@@ -55,9 +65,17 @@ type Config struct {
|
||||
// QuotaExceeded defines the behavior when a quota is exceeded.
|
||||
QuotaExceeded QuotaExceeded `yaml:"quota-exceeded" json:"quota-exceeded"`
|
||||
|
||||
// Routing controls credential selection behavior.
|
||||
Routing RoutingConfig `yaml:"routing" json:"routing"`
|
||||
|
||||
// WebsocketAuth enables or disables authentication for the WebSocket API.
|
||||
WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"`
|
||||
|
||||
// CodexInstructionsEnabled controls whether official Codex instructions are injected.
|
||||
// When false (default), CodexInstructionsForModel returns immediately without modification.
|
||||
// When true, the original instruction injection logic is used.
|
||||
CodexInstructionsEnabled bool `yaml:"codex-instructions-enabled" json:"codex-instructions-enabled"`
|
||||
|
||||
// GeminiKey defines Gemini API key configurations with optional routing overrides.
|
||||
GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"`
|
||||
|
||||
@@ -85,8 +103,17 @@ type Config struct {
|
||||
AmpCode AmpCode `yaml:"ampcode" json:"ampcode"`
|
||||
|
||||
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
|
||||
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
|
||||
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
|
||||
|
||||
// OAuthModelAlias defines global model name aliases for OAuth/file-backed auth channels.
|
||||
// These aliases affect both model listing and model routing for supported channels:
|
||||
// gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
|
||||
//
|
||||
// NOTE: This does not apply to existing per-credential model alias features under:
|
||||
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
|
||||
OAuthModelAlias map[string][]OAuthModelAlias `yaml:"oauth-model-alias,omitempty" json:"oauth-model-alias,omitempty"`
|
||||
|
||||
// Payload defines default and override rules for provider payload parameters.
|
||||
Payload PayloadConfig `yaml:"payload" json:"payload"`
|
||||
|
||||
@@ -116,6 +143,9 @@ type RemoteManagement struct {
|
||||
SecretKey string `yaml:"secret-key"`
|
||||
// DisableControlPanel skips serving and syncing the bundled management UI when true.
|
||||
DisableControlPanel bool `yaml:"disable-control-panel"`
|
||||
// PanelGitHubRepository overrides the GitHub repository used to fetch the management panel asset.
|
||||
// Accepts either a repository URL (https://github.com/org/repo) or an API releases endpoint.
|
||||
PanelGitHubRepository string `yaml:"panel-github-repository"`
|
||||
}
|
||||
|
||||
// QuotaExceeded defines the behavior when API quota limits are exceeded.
|
||||
@@ -128,6 +158,23 @@ type QuotaExceeded struct {
|
||||
SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"`
|
||||
}
|
||||
|
||||
// RoutingConfig configures how credentials are selected for requests.
|
||||
type RoutingConfig struct {
|
||||
// Strategy selects the credential selection strategy.
|
||||
// Supported values: "round-robin" (default), "fill-first".
|
||||
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
|
||||
}
|
||||
|
||||
// OAuthModelAlias defines a model ID alias for a specific channel.
|
||||
// It maps the upstream model name (Name) to the client-visible alias (Alias).
|
||||
// When Fork is true, the alias is added as an additional model in listings while
|
||||
// keeping the original model ID available.
|
||||
type OAuthModelAlias struct {
|
||||
Name string `yaml:"name" json:"name"`
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
Fork bool `yaml:"fork,omitempty" json:"fork,omitempty"`
|
||||
}
|
||||
|
||||
// AmpModelMapping defines a model name mapping for Amp CLI requests.
|
||||
// When Amp requests a model that isn't available locally, this mapping
|
||||
// allows routing to an alternative model that IS available.
|
||||
@@ -138,6 +185,11 @@ type AmpModelMapping struct {
|
||||
// To is the target model name to route to (e.g., "claude-sonnet-4").
|
||||
// The target model must have available providers in the registry.
|
||||
To string `yaml:"to" json:"to"`
|
||||
|
||||
// Regex indicates whether the 'from' field should be interpreted as a regular
|
||||
// expression for matching model names. When true, this mapping is evaluated
|
||||
// after exact matches and in the order provided. Defaults to false (exact match).
|
||||
Regex bool `yaml:"regex,omitempty" json:"regex,omitempty"`
|
||||
}
|
||||
|
||||
// AmpCode groups Amp CLI integration settings including upstream routing,
|
||||
@@ -149,9 +201,14 @@ type AmpCode struct {
|
||||
// UpstreamAPIKey optionally overrides the Authorization header when proxying Amp upstream calls.
|
||||
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
|
||||
|
||||
// UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys.
|
||||
// When a client authenticates with a key that matches an entry, that upstream key is used.
|
||||
// If no match is found, falls back to UpstreamAPIKey (default behavior).
|
||||
UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"`
|
||||
|
||||
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
|
||||
// to only accept connections from localhost (127.0.0.1, ::1). When true, prevents drive-by
|
||||
// browser attacks and remote access to management endpoints. Default: true (recommended).
|
||||
// browser attacks and remote access to management endpoints. Default: false (API key auth is sufficient).
|
||||
RestrictManagementToLocalhost bool `yaml:"restrict-management-to-localhost" json:"restrict-management-to-localhost"`
|
||||
|
||||
// ModelMappings defines model name mappings for Amp CLI requests.
|
||||
@@ -164,12 +221,27 @@ type AmpCode struct {
|
||||
ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"`
|
||||
}
|
||||
|
||||
// AmpUpstreamAPIKeyEntry maps a set of client API keys to a specific upstream API key.
|
||||
// When a request is authenticated with one of the APIKeys, the corresponding UpstreamAPIKey
|
||||
// is used for the upstream Amp request.
|
||||
type AmpUpstreamAPIKeyEntry struct {
|
||||
// UpstreamAPIKey is the API key to use when proxying to the Amp upstream.
|
||||
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
|
||||
|
||||
// APIKeys are the client API keys (from top-level api-keys) that map to this upstream key.
|
||||
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||
}
|
||||
|
||||
// PayloadConfig defines default and override parameter rules applied to provider payloads.
|
||||
type PayloadConfig struct {
|
||||
// Default defines rules that only set parameters when they are missing in the payload.
|
||||
Default []PayloadRule `yaml:"default" json:"default"`
|
||||
// DefaultRaw defines rules that set raw JSON values only when they are missing.
|
||||
DefaultRaw []PayloadRule `yaml:"default-raw" json:"default-raw"`
|
||||
// Override defines rules that always set parameters, overwriting any existing values.
|
||||
Override []PayloadRule `yaml:"override" json:"override"`
|
||||
// OverrideRaw defines rules that always set raw JSON values, overwriting any existing values.
|
||||
OverrideRaw []PayloadRule `yaml:"override-raw" json:"override-raw"`
|
||||
}
|
||||
|
||||
// PayloadRule describes a single rule targeting a list of models with parameter updates.
|
||||
@@ -177,6 +249,7 @@ type PayloadRule struct {
|
||||
// Models lists model entries with name pattern and protocol constraint.
|
||||
Models []PayloadModelRule `yaml:"models" json:"models"`
|
||||
// Params maps JSON paths (gjson/sjson syntax) to values written into the payload.
|
||||
// For *-raw rules, values are treated as raw JSON fragments (strings are used as-is).
|
||||
Params map[string]any `yaml:"params" json:"params"`
|
||||
}
|
||||
|
||||
@@ -188,12 +261,38 @@ type PayloadModelRule struct {
|
||||
Protocol string `yaml:"protocol" json:"protocol"`
|
||||
}
|
||||
|
||||
// CloakConfig configures request cloaking for non-Claude-Code clients.
|
||||
// Cloaking disguises API requests to appear as originating from the official Claude Code CLI.
|
||||
type CloakConfig struct {
|
||||
// Mode controls cloaking behavior: "auto" (default), "always", or "never".
|
||||
// - "auto": cloak only when client is not Claude Code (based on User-Agent)
|
||||
// - "always": always apply cloaking regardless of client
|
||||
// - "never": never apply cloaking
|
||||
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"`
|
||||
|
||||
// StrictMode controls how system prompts are handled when cloaking.
|
||||
// - false (default): prepend Claude Code prompt to user system messages
|
||||
// - true: strip all user system messages, keep only Claude Code prompt
|
||||
StrictMode bool `yaml:"strict-mode,omitempty" json:"strict-mode,omitempty"`
|
||||
|
||||
// SensitiveWords is a list of words to obfuscate with zero-width characters.
|
||||
// This can help bypass certain content filters.
|
||||
SensitiveWords []string `yaml:"sensitive-words,omitempty" json:"sensitive-words,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeKey represents the configuration for a Claude API key,
|
||||
// including the API key itself and an optional base URL for the API endpoint.
|
||||
type ClaudeKey struct {
|
||||
// APIKey is the authentication key for accessing Claude API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Priority controls selection preference when multiple credentials match.
|
||||
// Higher values are preferred; defaults to 0.
|
||||
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/claude-sonnet-4").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the Claude API endpoint.
|
||||
// If empty, the default Claude API URL will be used.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
@@ -209,8 +308,14 @@ type ClaudeKey struct {
|
||||
|
||||
// ExcludedModels lists model IDs that should be excluded for this provider.
|
||||
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
||||
|
||||
// Cloak configures request cloaking for non-Claude-Code clients.
|
||||
Cloak *CloakConfig `yaml:"cloak,omitempty" json:"cloak,omitempty"`
|
||||
}
|
||||
|
||||
func (k ClaudeKey) GetAPIKey() string { return k.APIKey }
|
||||
func (k ClaudeKey) GetBaseURL() string { return k.BaseURL }
|
||||
|
||||
// ClaudeModel describes a mapping between an alias and the actual upstream model name.
|
||||
type ClaudeModel struct {
|
||||
// Name is the upstream model identifier used when issuing requests.
|
||||
@@ -220,12 +325,22 @@ type ClaudeModel struct {
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
}
|
||||
|
||||
func (m ClaudeModel) GetName() string { return m.Name }
|
||||
func (m ClaudeModel) GetAlias() string { return m.Alias }
|
||||
|
||||
// CodexKey represents the configuration for a Codex API key,
|
||||
// including the API key itself and an optional base URL for the API endpoint.
|
||||
type CodexKey struct {
|
||||
// APIKey is the authentication key for accessing Codex API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Priority controls selection preference when multiple credentials match.
|
||||
// Higher values are preferred; defaults to 0.
|
||||
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/gpt-5-codex").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the Codex API endpoint.
|
||||
// If empty, the default Codex API URL will be used.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
@@ -233,6 +348,9 @@ type CodexKey struct {
|
||||
// ProxyURL overrides the global proxy setting for this API key if provided.
|
||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||
|
||||
// Models defines upstream model names and aliases for request routing.
|
||||
Models []CodexModel `yaml:"models" json:"models"`
|
||||
|
||||
// Headers optionally adds extra HTTP headers for requests sent with this key.
|
||||
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
|
||||
|
||||
@@ -240,18 +358,43 @@ type CodexKey struct {
|
||||
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
||||
}
|
||||
|
||||
func (k CodexKey) GetAPIKey() string { return k.APIKey }
|
||||
func (k CodexKey) GetBaseURL() string { return k.BaseURL }
|
||||
|
||||
// CodexModel describes a mapping between an alias and the actual upstream model name.
|
||||
type CodexModel struct {
|
||||
// Name is the upstream model identifier used when issuing requests.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Alias is the client-facing model name that maps to Name.
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
}
|
||||
|
||||
func (m CodexModel) GetName() string { return m.Name }
|
||||
func (m CodexModel) GetAlias() string { return m.Alias }
|
||||
|
||||
// GeminiKey represents the configuration for a Gemini API key,
|
||||
// including optional overrides for upstream base URL, proxy routing, and headers.
|
||||
type GeminiKey struct {
|
||||
// APIKey is the authentication key for accessing Gemini API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Priority controls selection preference when multiple credentials match.
|
||||
// Higher values are preferred; defaults to 0.
|
||||
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/gemini-3-pro-preview").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL optionally overrides the Gemini API endpoint.
|
||||
BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"`
|
||||
|
||||
// ProxyURL optionally overrides the global proxy for this API key.
|
||||
ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
|
||||
|
||||
// Models defines upstream model names and aliases for request routing.
|
||||
Models []GeminiModel `yaml:"models,omitempty" json:"models,omitempty"`
|
||||
|
||||
// Headers optionally adds extra HTTP headers for requests sent with this key.
|
||||
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
|
||||
|
||||
@@ -259,6 +402,21 @@ type GeminiKey struct {
|
||||
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
||||
}
|
||||
|
||||
func (k GeminiKey) GetAPIKey() string { return k.APIKey }
|
||||
func (k GeminiKey) GetBaseURL() string { return k.BaseURL }
|
||||
|
||||
// GeminiModel describes a mapping between an alias and the actual upstream model name.
|
||||
type GeminiModel struct {
|
||||
// Name is the upstream model identifier used when issuing requests.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Alias is the client-facing model name that maps to Name.
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
}
|
||||
|
||||
func (m GeminiModel) GetName() string { return m.Name }
|
||||
func (m GeminiModel) GetAlias() string { return m.Alias }
|
||||
|
||||
// KiroKey represents the configuration for Kiro (AWS CodeWhisperer) authentication.
|
||||
type KiroKey struct {
|
||||
// TokenFile is the path to the Kiro token file (default: ~/.aws/sso/cache/kiro-auth-token.json)
|
||||
@@ -294,6 +452,13 @@ type OpenAICompatibility struct {
|
||||
// Name is the identifier for this OpenAI compatibility configuration.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Priority controls selection preference when multiple providers or credentials match.
|
||||
// Higher values are preferred; defaults to 0.
|
||||
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||
|
||||
// Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the external OpenAI-compatible API endpoint.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
|
||||
@@ -326,6 +491,9 @@ type OpenAICompatibilityModel struct {
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
}
|
||||
|
||||
func (m OpenAICompatibilityModel) GetName() string { return m.Name }
|
||||
func (m OpenAICompatibilityModel) GetAlias() string { return m.Alias }
|
||||
|
||||
// LoadConfig reads a YAML configuration file from the given path,
|
||||
// unmarshals it into a Config struct, applies environment variable overrides,
|
||||
// and returns it.
|
||||
@@ -344,6 +512,15 @@ func LoadConfig(configFile string) (*Config, error) {
|
||||
// If optional is true and the file is missing, it returns an empty Config.
|
||||
// If optional is true and the file is empty or invalid, it returns an empty Config.
|
||||
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Perform oauth-model-alias migration before loading config.
|
||||
// This migrates oauth-model-mappings to oauth-model-alias if needed.
|
||||
if migrated, err := MigrateOAuthModelAlias(configFile); err != nil {
|
||||
// Log warning but don't fail - config loading should still work
|
||||
fmt.Printf("Warning: oauth-model-alias migration failed: %v\n", err)
|
||||
} else if migrated {
|
||||
fmt.Println("Migrated oauth-model-mappings to oauth-model-alias")
|
||||
}
|
||||
|
||||
// Read the entire configuration file into memory.
|
||||
data, err := os.ReadFile(configFile)
|
||||
if err != nil {
|
||||
@@ -366,10 +543,12 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Set defaults before unmarshal so that absent keys keep defaults.
|
||||
cfg.Host = "" // Default empty: binds to all interfaces (IPv4 + IPv6)
|
||||
cfg.LoggingToFile = false
|
||||
cfg.LogsMaxTotalSizeMB = 0
|
||||
cfg.UsageStatisticsEnabled = false
|
||||
cfg.DisableCooling = false
|
||||
cfg.AmpCode.RestrictManagementToLocalhost = true // Default to secure: only localhost access
|
||||
cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force)
|
||||
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
||||
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
|
||||
cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force)
|
||||
if err = yaml.Unmarshal(data, &cfg); err != nil {
|
||||
if optional {
|
||||
// In cloud deploy mode, if YAML parsing fails, return empty config instead of error.
|
||||
@@ -405,6 +584,15 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
_ = SaveConfigPreserveCommentsUpdateNestedScalar(configFile, []string{"remote-management", "secret-key"}, hashed)
|
||||
}
|
||||
|
||||
cfg.RemoteManagement.PanelGitHubRepository = strings.TrimSpace(cfg.RemoteManagement.PanelGitHubRepository)
|
||||
if cfg.RemoteManagement.PanelGitHubRepository == "" {
|
||||
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
|
||||
}
|
||||
|
||||
if cfg.LogsMaxTotalSizeMB < 0 {
|
||||
cfg.LogsMaxTotalSizeMB = 0
|
||||
}
|
||||
|
||||
// Sync request authentication providers with inline API keys for backwards compatibility.
|
||||
syncInlineAccessProvider(&cfg)
|
||||
|
||||
@@ -429,6 +617,12 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Normalize OAuth provider model exclusion map.
|
||||
cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels)
|
||||
|
||||
// Normalize global OAuth model name aliases.
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
// Validate raw payload rules and drop invalid entries.
|
||||
cfg.SanitizePayloadRules()
|
||||
|
||||
if cfg.legacyMigrationPending {
|
||||
fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
|
||||
if !optional && configFile != "" {
|
||||
@@ -445,6 +639,99 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// SanitizePayloadRules validates raw JSON payload rule params and drops invalid rules.
|
||||
func (cfg *Config) SanitizePayloadRules() {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
cfg.Payload.DefaultRaw = sanitizePayloadRawRules(cfg.Payload.DefaultRaw, "default-raw")
|
||||
cfg.Payload.OverrideRaw = sanitizePayloadRawRules(cfg.Payload.OverrideRaw, "override-raw")
|
||||
}
|
||||
|
||||
func sanitizePayloadRawRules(rules []PayloadRule, section string) []PayloadRule {
|
||||
if len(rules) == 0 {
|
||||
return rules
|
||||
}
|
||||
out := make([]PayloadRule, 0, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
if len(rule.Params) == 0 {
|
||||
continue
|
||||
}
|
||||
invalid := false
|
||||
for path, value := range rule.Params {
|
||||
raw, ok := payloadRawString(value)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
trimmed := bytes.TrimSpace(raw)
|
||||
if len(trimmed) == 0 || !json.Valid(trimmed) {
|
||||
log.WithFields(log.Fields{
|
||||
"section": section,
|
||||
"rule_index": i + 1,
|
||||
"param": path,
|
||||
}).Warn("payload rule dropped: invalid raw JSON")
|
||||
invalid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if invalid {
|
||||
continue
|
||||
}
|
||||
out = append(out, rule)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func payloadRawString(value any) ([]byte, bool) {
|
||||
switch typed := value.(type) {
|
||||
case string:
|
||||
return []byte(typed), true
|
||||
case []byte:
|
||||
return typed, true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
|
||||
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
|
||||
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
|
||||
func (cfg *Config) SanitizeOAuthModelAlias() {
|
||||
if cfg == nil || len(cfg.OAuthModelAlias) == 0 {
|
||||
return
|
||||
}
|
||||
out := make(map[string][]OAuthModelAlias, len(cfg.OAuthModelAlias))
|
||||
for rawChannel, aliases := range cfg.OAuthModelAlias {
|
||||
channel := strings.ToLower(strings.TrimSpace(rawChannel))
|
||||
if channel == "" || len(aliases) == 0 {
|
||||
continue
|
||||
}
|
||||
seenAlias := make(map[string]struct{}, len(aliases))
|
||||
clean := make([]OAuthModelAlias, 0, len(aliases))
|
||||
for _, entry := range aliases {
|
||||
name := strings.TrimSpace(entry.Name)
|
||||
alias := strings.TrimSpace(entry.Alias)
|
||||
if name == "" || alias == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(name, alias) {
|
||||
continue
|
||||
}
|
||||
aliasKey := strings.ToLower(alias)
|
||||
if _, ok := seenAlias[aliasKey]; ok {
|
||||
continue
|
||||
}
|
||||
seenAlias[aliasKey] = struct{}{}
|
||||
clean = append(clean, OAuthModelAlias{Name: name, Alias: alias, Fork: entry.Fork})
|
||||
}
|
||||
if len(clean) > 0 {
|
||||
out[channel] = clean
|
||||
}
|
||||
}
|
||||
cfg.OAuthModelAlias = out
|
||||
}
|
||||
|
||||
// SanitizeOpenAICompatibility removes OpenAI-compatibility provider entries that are
|
||||
// not actionable, specifically those missing a BaseURL. It trims whitespace before
|
||||
// evaluation and preserves the relative order of remaining entries.
|
||||
@@ -456,6 +743,7 @@ func (cfg *Config) SanitizeOpenAICompatibility() {
|
||||
for i := range cfg.OpenAICompatibility {
|
||||
e := cfg.OpenAICompatibility[i]
|
||||
e.Name = strings.TrimSpace(e.Name)
|
||||
e.Prefix = normalizeModelPrefix(e.Prefix)
|
||||
e.BaseURL = strings.TrimSpace(e.BaseURL)
|
||||
e.Headers = NormalizeHeaders(e.Headers)
|
||||
if e.BaseURL == "" {
|
||||
@@ -476,6 +764,7 @@ func (cfg *Config) SanitizeCodexKeys() {
|
||||
out := make([]CodexKey, 0, len(cfg.CodexKey))
|
||||
for i := range cfg.CodexKey {
|
||||
e := cfg.CodexKey[i]
|
||||
e.Prefix = normalizeModelPrefix(e.Prefix)
|
||||
e.BaseURL = strings.TrimSpace(e.BaseURL)
|
||||
e.Headers = NormalizeHeaders(e.Headers)
|
||||
e.ExcludedModels = NormalizeExcludedModels(e.ExcludedModels)
|
||||
@@ -494,6 +783,7 @@ func (cfg *Config) SanitizeClaudeKeys() {
|
||||
}
|
||||
for i := range cfg.ClaudeKey {
|
||||
entry := &cfg.ClaudeKey[i]
|
||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
||||
}
|
||||
@@ -530,6 +820,7 @@ func (cfg *Config) SanitizeGeminiKeys() {
|
||||
if entry.APIKey == "" {
|
||||
continue
|
||||
}
|
||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||
@@ -543,6 +834,18 @@ func (cfg *Config) SanitizeGeminiKeys() {
|
||||
cfg.GeminiKey = out
|
||||
}
|
||||
|
||||
func normalizeModelPrefix(prefix string) string {
|
||||
trimmed := strings.TrimSpace(prefix)
|
||||
trimmed = strings.Trim(trimmed, "/")
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(trimmed, "/") {
|
||||
return ""
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func syncInlineAccessProvider(cfg *Config) {
|
||||
if cfg == nil {
|
||||
return
|
||||
@@ -715,7 +1018,7 @@ func sanitizeConfigForPersist(cfg *Config) *Config {
|
||||
}
|
||||
clone := *cfg
|
||||
clone.SDKConfig = cfg.SDKConfig
|
||||
clone.SDKConfig.Access = config.AccessConfig{}
|
||||
clone.SDKConfig.Access = AccessConfig{}
|
||||
return &clone
|
||||
}
|
||||
|
||||
@@ -814,8 +1117,8 @@ func getOrCreateMapValue(mapNode *yaml.Node, key string) *yaml.Node {
|
||||
}
|
||||
|
||||
// mergeMappingPreserve merges keys from src into dst mapping node while preserving
|
||||
// key order and comments of existing keys in dst. Unknown keys from src are appended
|
||||
// to dst at the end, copying their node structure from src.
|
||||
// key order and comments of existing keys in dst. New keys are only added if their
|
||||
// value is non-zero to avoid polluting the config with defaults.
|
||||
func mergeMappingPreserve(dst, src *yaml.Node) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
@@ -826,20 +1129,19 @@ func mergeMappingPreserve(dst, src *yaml.Node) {
|
||||
copyNodeShallow(dst, src)
|
||||
return
|
||||
}
|
||||
// Build a lookup of existing keys in dst
|
||||
for i := 0; i+1 < len(src.Content); i += 2 {
|
||||
sk := src.Content[i]
|
||||
sv := src.Content[i+1]
|
||||
idx := findMapKeyIndex(dst, sk.Value)
|
||||
if idx >= 0 {
|
||||
// Merge into existing value node
|
||||
// Merge into existing value node (always update, even to zero values)
|
||||
dv := dst.Content[idx+1]
|
||||
mergeNodePreserve(dv, sv)
|
||||
} else {
|
||||
if shouldSkipEmptyCollectionOnPersist(sk.Value, sv) {
|
||||
// New key: only add if value is non-zero to avoid polluting config with defaults
|
||||
if isZeroValueNode(sv) {
|
||||
continue
|
||||
}
|
||||
// Append new key/value pair by deep-copying from src
|
||||
dst.Content = append(dst.Content, deepCopyNode(sk), deepCopyNode(sv))
|
||||
}
|
||||
}
|
||||
@@ -922,32 +1224,49 @@ func findMapKeyIndex(mapNode *yaml.Node, key string) int {
|
||||
return -1
|
||||
}
|
||||
|
||||
func shouldSkipEmptyCollectionOnPersist(key string, node *yaml.Node) bool {
|
||||
switch key {
|
||||
case "generative-language-api-key",
|
||||
"gemini-api-key",
|
||||
"vertex-api-key",
|
||||
"claude-api-key",
|
||||
"codex-api-key",
|
||||
"openai-compatibility":
|
||||
return isEmptyCollectionNode(node)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isEmptyCollectionNode(node *yaml.Node) bool {
|
||||
// isZeroValueNode returns true if the YAML node represents a zero/default value
|
||||
// that should not be written as a new key to preserve config cleanliness.
|
||||
// For mappings and sequences, recursively checks if all children are zero values.
|
||||
func isZeroValueNode(node *yaml.Node) bool {
|
||||
if node == nil {
|
||||
return true
|
||||
}
|
||||
switch node.Kind {
|
||||
case yaml.SequenceNode:
|
||||
return len(node.Content) == 0
|
||||
case yaml.ScalarNode:
|
||||
return node.Tag == "!!null"
|
||||
default:
|
||||
return false
|
||||
switch node.Tag {
|
||||
case "!!bool":
|
||||
return node.Value == "false"
|
||||
case "!!int", "!!float":
|
||||
return node.Value == "0" || node.Value == "0.0"
|
||||
case "!!str":
|
||||
return node.Value == ""
|
||||
case "!!null":
|
||||
return true
|
||||
}
|
||||
case yaml.SequenceNode:
|
||||
if len(node.Content) == 0 {
|
||||
return true
|
||||
}
|
||||
// Check if all elements are zero values
|
||||
for _, child := range node.Content {
|
||||
if !isZeroValueNode(child) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
case yaml.MappingNode:
|
||||
if len(node.Content) == 0 {
|
||||
return true
|
||||
}
|
||||
// Check if all values are zero values (values are at odd indices)
|
||||
for i := 1; i < len(node.Content); i += 2 {
|
||||
if !isZeroValueNode(node.Content[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// deepCopyNode creates a deep copy of a yaml.Node graph.
|
||||
|
||||
275
internal/config/oauth_model_alias_migration.go
Normal file
275
internal/config/oauth_model_alias_migration.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// antigravityModelConversionTable maps old built-in aliases to actual model names
|
||||
// for the antigravity channel during migration.
|
||||
var antigravityModelConversionTable = map[string]string{
|
||||
"gemini-2.5-computer-use-preview-10-2025": "rev19-uic3-1p",
|
||||
"gemini-3-pro-image-preview": "gemini-3-pro-image",
|
||||
"gemini-3-pro-preview": "gemini-3-pro-high",
|
||||
"gemini-3-flash-preview": "gemini-3-flash",
|
||||
"gemini-claude-sonnet-4-5": "claude-sonnet-4-5",
|
||||
"gemini-claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
|
||||
"gemini-claude-opus-4-5-thinking": "claude-opus-4-5-thinking",
|
||||
}
|
||||
|
||||
// defaultAntigravityAliases returns the default oauth-model-alias configuration
|
||||
// for the antigravity channel when neither field exists.
|
||||
func defaultAntigravityAliases() []OAuthModelAlias {
|
||||
return []OAuthModelAlias{
|
||||
{Name: "rev19-uic3-1p", Alias: "gemini-2.5-computer-use-preview-10-2025"},
|
||||
{Name: "gemini-3-pro-image", Alias: "gemini-3-pro-image-preview"},
|
||||
{Name: "gemini-3-pro-high", Alias: "gemini-3-pro-preview"},
|
||||
{Name: "gemini-3-flash", Alias: "gemini-3-flash-preview"},
|
||||
{Name: "claude-sonnet-4-5", Alias: "gemini-claude-sonnet-4-5"},
|
||||
{Name: "claude-sonnet-4-5-thinking", Alias: "gemini-claude-sonnet-4-5-thinking"},
|
||||
{Name: "claude-opus-4-5-thinking", Alias: "gemini-claude-opus-4-5-thinking"},
|
||||
}
|
||||
}
|
||||
|
||||
// MigrateOAuthModelAlias checks for and performs migration from oauth-model-mappings
|
||||
// to oauth-model-alias at startup. Returns true if migration was performed.
|
||||
//
|
||||
// Migration flow:
|
||||
// 1. Check if oauth-model-alias exists -> skip migration
|
||||
// 2. Check if oauth-model-mappings exists -> convert and migrate
|
||||
// - For antigravity channel, convert old built-in aliases to actual model names
|
||||
//
|
||||
// 3. Neither exists -> add default antigravity config
|
||||
func MigrateOAuthModelAlias(configFile string) (bool, error) {
|
||||
data, err := os.ReadFile(configFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Parse YAML into node tree to preserve structure
|
||||
var root yaml.Node
|
||||
if err := yaml.Unmarshal(data, &root); err != nil {
|
||||
return false, nil
|
||||
}
|
||||
if root.Kind != yaml.DocumentNode || len(root.Content) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
rootMap := root.Content[0]
|
||||
if rootMap == nil || rootMap.Kind != yaml.MappingNode {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Check if oauth-model-alias already exists
|
||||
if findMapKeyIndex(rootMap, "oauth-model-alias") >= 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Check if oauth-model-mappings exists
|
||||
oldIdx := findMapKeyIndex(rootMap, "oauth-model-mappings")
|
||||
if oldIdx >= 0 {
|
||||
// Migrate from old field
|
||||
return migrateFromOldField(configFile, &root, rootMap, oldIdx)
|
||||
}
|
||||
|
||||
// Neither field exists - add default antigravity config
|
||||
return addDefaultAntigravityConfig(configFile, &root, rootMap)
|
||||
}
|
||||
|
||||
// migrateFromOldField converts oauth-model-mappings to oauth-model-alias
|
||||
func migrateFromOldField(configFile string, root *yaml.Node, rootMap *yaml.Node, oldIdx int) (bool, error) {
|
||||
if oldIdx+1 >= len(rootMap.Content) {
|
||||
return false, nil
|
||||
}
|
||||
oldValue := rootMap.Content[oldIdx+1]
|
||||
if oldValue == nil || oldValue.Kind != yaml.MappingNode {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Parse the old aliases
|
||||
oldAliases := parseOldAliasNode(oldValue)
|
||||
if len(oldAliases) == 0 {
|
||||
// Remove the old field and write
|
||||
removeMapKeyByIndex(rootMap, oldIdx)
|
||||
return writeYAMLNode(configFile, root)
|
||||
}
|
||||
|
||||
// Convert model names for antigravity channel
|
||||
newAliases := make(map[string][]OAuthModelAlias, len(oldAliases))
|
||||
for channel, entries := range oldAliases {
|
||||
converted := make([]OAuthModelAlias, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
newEntry := OAuthModelAlias{
|
||||
Name: entry.Name,
|
||||
Alias: entry.Alias,
|
||||
Fork: entry.Fork,
|
||||
}
|
||||
// Convert model names for antigravity channel
|
||||
if strings.EqualFold(channel, "antigravity") {
|
||||
if actual, ok := antigravityModelConversionTable[entry.Name]; ok {
|
||||
newEntry.Name = actual
|
||||
}
|
||||
}
|
||||
converted = append(converted, newEntry)
|
||||
}
|
||||
newAliases[channel] = converted
|
||||
}
|
||||
|
||||
// For antigravity channel, supplement missing default aliases
|
||||
if antigravityEntries, exists := newAliases["antigravity"]; exists {
|
||||
// Build a set of already configured model names (upstream names)
|
||||
configuredModels := make(map[string]bool, len(antigravityEntries))
|
||||
for _, entry := range antigravityEntries {
|
||||
configuredModels[entry.Name] = true
|
||||
}
|
||||
|
||||
// Add missing default aliases
|
||||
for _, defaultAlias := range defaultAntigravityAliases() {
|
||||
if !configuredModels[defaultAlias.Name] {
|
||||
antigravityEntries = append(antigravityEntries, defaultAlias)
|
||||
}
|
||||
}
|
||||
newAliases["antigravity"] = antigravityEntries
|
||||
}
|
||||
|
||||
// Build new node
|
||||
newNode := buildOAuthModelAliasNode(newAliases)
|
||||
|
||||
// Replace old key with new key and value
|
||||
rootMap.Content[oldIdx].Value = "oauth-model-alias"
|
||||
rootMap.Content[oldIdx+1] = newNode
|
||||
|
||||
return writeYAMLNode(configFile, root)
|
||||
}
|
||||
|
||||
// addDefaultAntigravityConfig adds the default antigravity configuration
|
||||
func addDefaultAntigravityConfig(configFile string, root *yaml.Node, rootMap *yaml.Node) (bool, error) {
|
||||
defaults := map[string][]OAuthModelAlias{
|
||||
"antigravity": defaultAntigravityAliases(),
|
||||
}
|
||||
newNode := buildOAuthModelAliasNode(defaults)
|
||||
|
||||
// Add new key-value pair
|
||||
keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "oauth-model-alias"}
|
||||
rootMap.Content = append(rootMap.Content, keyNode, newNode)
|
||||
|
||||
return writeYAMLNode(configFile, root)
|
||||
}
|
||||
|
||||
// parseOldAliasNode parses the old oauth-model-mappings node structure
|
||||
func parseOldAliasNode(node *yaml.Node) map[string][]OAuthModelAlias {
|
||||
if node == nil || node.Kind != yaml.MappingNode {
|
||||
return nil
|
||||
}
|
||||
result := make(map[string][]OAuthModelAlias)
|
||||
for i := 0; i+1 < len(node.Content); i += 2 {
|
||||
channelNode := node.Content[i]
|
||||
entriesNode := node.Content[i+1]
|
||||
if channelNode == nil || entriesNode == nil {
|
||||
continue
|
||||
}
|
||||
channel := strings.ToLower(strings.TrimSpace(channelNode.Value))
|
||||
if channel == "" || entriesNode.Kind != yaml.SequenceNode {
|
||||
continue
|
||||
}
|
||||
entries := make([]OAuthModelAlias, 0, len(entriesNode.Content))
|
||||
for _, entryNode := range entriesNode.Content {
|
||||
if entryNode == nil || entryNode.Kind != yaml.MappingNode {
|
||||
continue
|
||||
}
|
||||
entry := parseAliasEntry(entryNode)
|
||||
if entry.Name != "" && entry.Alias != "" {
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
}
|
||||
if len(entries) > 0 {
|
||||
result[channel] = entries
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// parseAliasEntry parses a single alias entry node
|
||||
func parseAliasEntry(node *yaml.Node) OAuthModelAlias {
|
||||
var entry OAuthModelAlias
|
||||
for i := 0; i+1 < len(node.Content); i += 2 {
|
||||
keyNode := node.Content[i]
|
||||
valNode := node.Content[i+1]
|
||||
if keyNode == nil || valNode == nil {
|
||||
continue
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(keyNode.Value)) {
|
||||
case "name":
|
||||
entry.Name = strings.TrimSpace(valNode.Value)
|
||||
case "alias":
|
||||
entry.Alias = strings.TrimSpace(valNode.Value)
|
||||
case "fork":
|
||||
entry.Fork = strings.ToLower(strings.TrimSpace(valNode.Value)) == "true"
|
||||
}
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
// buildOAuthModelAliasNode creates a YAML node for oauth-model-alias
|
||||
func buildOAuthModelAliasNode(aliases map[string][]OAuthModelAlias) *yaml.Node {
|
||||
node := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
|
||||
for channel, entries := range aliases {
|
||||
channelNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: channel}
|
||||
entriesNode := &yaml.Node{Kind: yaml.SequenceNode, Tag: "!!seq"}
|
||||
for _, entry := range entries {
|
||||
entryNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
|
||||
entryNode.Content = append(entryNode.Content,
|
||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "name"},
|
||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: entry.Name},
|
||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "alias"},
|
||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: entry.Alias},
|
||||
)
|
||||
if entry.Fork {
|
||||
entryNode.Content = append(entryNode.Content,
|
||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "fork"},
|
||||
&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!bool", Value: "true"},
|
||||
)
|
||||
}
|
||||
entriesNode.Content = append(entriesNode.Content, entryNode)
|
||||
}
|
||||
node.Content = append(node.Content, channelNode, entriesNode)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// removeMapKeyByIndex removes a key-value pair from a mapping node by index
|
||||
func removeMapKeyByIndex(mapNode *yaml.Node, keyIdx int) {
|
||||
if mapNode == nil || mapNode.Kind != yaml.MappingNode {
|
||||
return
|
||||
}
|
||||
if keyIdx < 0 || keyIdx+1 >= len(mapNode.Content) {
|
||||
return
|
||||
}
|
||||
mapNode.Content = append(mapNode.Content[:keyIdx], mapNode.Content[keyIdx+2:]...)
|
||||
}
|
||||
|
||||
// writeYAMLNode writes the YAML node tree back to file
|
||||
func writeYAMLNode(configFile string, root *yaml.Node) (bool, error) {
|
||||
f, err := os.Create(configFile)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
enc := yaml.NewEncoder(f)
|
||||
enc.SetIndent(2)
|
||||
if err := enc.Encode(root); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := enc.Close(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
242
internal/config/oauth_model_alias_migration_test.go
Normal file
242
internal/config/oauth_model_alias_migration_test.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestMigrateOAuthModelAlias_SkipsIfNewFieldExists(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
configFile := filepath.Join(dir, "config.yaml")
|
||||
|
||||
content := `oauth-model-alias:
|
||||
gemini-cli:
|
||||
- name: "gemini-2.5-pro"
|
||||
alias: "g2.5p"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if migrated {
|
||||
t.Fatal("expected no migration when oauth-model-alias already exists")
|
||||
}
|
||||
|
||||
// Verify file unchanged
|
||||
data, _ := os.ReadFile(configFile)
|
||||
if !strings.Contains(string(data), "oauth-model-alias:") {
|
||||
t.Fatal("file should still contain oauth-model-alias")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateOAuthModelAlias_MigratesOldField(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
configFile := filepath.Join(dir, "config.yaml")
|
||||
|
||||
content := `oauth-model-mappings:
|
||||
gemini-cli:
|
||||
- name: "gemini-2.5-pro"
|
||||
alias: "g2.5p"
|
||||
fork: true
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatal("expected migration to occur")
|
||||
}
|
||||
|
||||
// Verify new field exists and old field removed
|
||||
data, _ := os.ReadFile(configFile)
|
||||
if strings.Contains(string(data), "oauth-model-mappings:") {
|
||||
t.Fatal("old field should be removed")
|
||||
}
|
||||
if !strings.Contains(string(data), "oauth-model-alias:") {
|
||||
t.Fatal("new field should exist")
|
||||
}
|
||||
|
||||
// Parse and verify structure
|
||||
var root yaml.Node
|
||||
if err := yaml.Unmarshal(data, &root); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateOAuthModelAlias_ConvertsAntigravityModels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
configFile := filepath.Join(dir, "config.yaml")
|
||||
|
||||
// Use old model names that should be converted
|
||||
content := `oauth-model-mappings:
|
||||
antigravity:
|
||||
- name: "gemini-2.5-computer-use-preview-10-2025"
|
||||
alias: "computer-use"
|
||||
- name: "gemini-3-pro-preview"
|
||||
alias: "g3p"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatal("expected migration to occur")
|
||||
}
|
||||
|
||||
// Verify model names were converted
|
||||
data, _ := os.ReadFile(configFile)
|
||||
content = string(data)
|
||||
if !strings.Contains(content, "rev19-uic3-1p") {
|
||||
t.Fatal("expected gemini-2.5-computer-use-preview-10-2025 to be converted to rev19-uic3-1p")
|
||||
}
|
||||
if !strings.Contains(content, "gemini-3-pro-high") {
|
||||
t.Fatal("expected gemini-3-pro-preview to be converted to gemini-3-pro-high")
|
||||
}
|
||||
|
||||
// Verify missing default aliases were supplemented
|
||||
if !strings.Contains(content, "gemini-3-pro-image") {
|
||||
t.Fatal("expected missing default alias gemini-3-pro-image to be added")
|
||||
}
|
||||
if !strings.Contains(content, "gemini-3-flash") {
|
||||
t.Fatal("expected missing default alias gemini-3-flash to be added")
|
||||
}
|
||||
if !strings.Contains(content, "claude-sonnet-4-5") {
|
||||
t.Fatal("expected missing default alias claude-sonnet-4-5 to be added")
|
||||
}
|
||||
if !strings.Contains(content, "claude-sonnet-4-5-thinking") {
|
||||
t.Fatal("expected missing default alias claude-sonnet-4-5-thinking to be added")
|
||||
}
|
||||
if !strings.Contains(content, "claude-opus-4-5-thinking") {
|
||||
t.Fatal("expected missing default alias claude-opus-4-5-thinking to be added")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateOAuthModelAlias_AddsDefaultIfNeitherExists(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
configFile := filepath.Join(dir, "config.yaml")
|
||||
|
||||
content := `debug: true
|
||||
port: 8080
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatal("expected migration to add default config")
|
||||
}
|
||||
|
||||
// Verify default antigravity config was added
|
||||
data, _ := os.ReadFile(configFile)
|
||||
content = string(data)
|
||||
if !strings.Contains(content, "oauth-model-alias:") {
|
||||
t.Fatal("expected oauth-model-alias to be added")
|
||||
}
|
||||
if !strings.Contains(content, "antigravity:") {
|
||||
t.Fatal("expected antigravity channel to be added")
|
||||
}
|
||||
if !strings.Contains(content, "rev19-uic3-1p") {
|
||||
t.Fatal("expected default antigravity aliases to include rev19-uic3-1p")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateOAuthModelAlias_PreservesOtherConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
configFile := filepath.Join(dir, "config.yaml")
|
||||
|
||||
content := `debug: true
|
||||
port: 8080
|
||||
oauth-model-mappings:
|
||||
gemini-cli:
|
||||
- name: "test"
|
||||
alias: "t"
|
||||
api-keys:
|
||||
- "key1"
|
||||
- "key2"
|
||||
`
|
||||
if err := os.WriteFile(configFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !migrated {
|
||||
t.Fatal("expected migration to occur")
|
||||
}
|
||||
|
||||
// Verify other config preserved
|
||||
data, _ := os.ReadFile(configFile)
|
||||
content = string(data)
|
||||
if !strings.Contains(content, "debug: true") {
|
||||
t.Fatal("expected debug field to be preserved")
|
||||
}
|
||||
if !strings.Contains(content, "port: 8080") {
|
||||
t.Fatal("expected port field to be preserved")
|
||||
}
|
||||
if !strings.Contains(content, "api-keys:") {
|
||||
t.Fatal("expected api-keys field to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateOAuthModelAlias_NonexistentFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
migrated, err := MigrateOAuthModelAlias("/nonexistent/path/config.yaml")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for nonexistent file: %v", err)
|
||||
}
|
||||
if migrated {
|
||||
t.Fatal("expected no migration for nonexistent file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateOAuthModelAlias_EmptyFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
configFile := filepath.Join(dir, "config.yaml")
|
||||
|
||||
if err := os.WriteFile(configFile, []byte(""), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
migrated, err := MigrateOAuthModelAlias(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if migrated {
|
||||
t.Fatal("expected no migration for empty file")
|
||||
}
|
||||
}
|
||||
56
internal/config/oauth_model_alias_test.go
Normal file
56
internal/config/oauth_model_alias_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package config
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSanitizeOAuthModelAlias_PreservesForkFlag(t *testing.T) {
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
" CoDeX ": {
|
||||
{Name: " gpt-5 ", Alias: " g5 ", Fork: true},
|
||||
{Name: "gpt-6", Alias: "g6"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
aliases := cfg.OAuthModelAlias["codex"]
|
||||
if len(aliases) != 2 {
|
||||
t.Fatalf("expected 2 sanitized aliases, got %d", len(aliases))
|
||||
}
|
||||
if aliases[0].Name != "gpt-5" || aliases[0].Alias != "g5" || !aliases[0].Fork {
|
||||
t.Fatalf("expected first alias to be gpt-5->g5 fork=true, got name=%q alias=%q fork=%v", aliases[0].Name, aliases[0].Alias, aliases[0].Fork)
|
||||
}
|
||||
if aliases[1].Name != "gpt-6" || aliases[1].Alias != "g6" || aliases[1].Fork {
|
||||
t.Fatalf("expected second alias to be gpt-6->g6 fork=false, got name=%q alias=%q fork=%v", aliases[1].Name, aliases[1].Alias, aliases[1].Fork)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_AllowsMultipleAliasesForSameName(t *testing.T) {
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"antigravity": {
|
||||
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101", Fork: true},
|
||||
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101-thinking", Fork: true},
|
||||
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5", Fork: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
aliases := cfg.OAuthModelAlias["antigravity"]
|
||||
expected := []OAuthModelAlias{
|
||||
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101", Fork: true},
|
||||
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101-thinking", Fork: true},
|
||||
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5", Fork: true},
|
||||
}
|
||||
if len(aliases) != len(expected) {
|
||||
t.Fatalf("expected %d sanitized aliases, got %d", len(expected), len(aliases))
|
||||
}
|
||||
for i, exp := range expected {
|
||||
if aliases[i].Name != exp.Name || aliases[i].Alias != exp.Alias || aliases[i].Fork != exp.Fork {
|
||||
t.Fatalf("expected alias %d to be name=%q alias=%q fork=%v, got name=%q alias=%q fork=%v", i, exp.Name, exp.Alias, exp.Fork, aliases[i].Name, aliases[i].Alias, aliases[i].Fork)
|
||||
}
|
||||
}
|
||||
}
|
||||
106
internal/config/sdk_config.go
Normal file
106
internal/config/sdk_config.go
Normal file
@@ -0,0 +1,106 @@
|
||||
// Package config provides configuration management for the CLI Proxy API server.
|
||||
// It handles loading and parsing YAML configuration files, and provides structured
|
||||
// access to application settings including server port, authentication directory,
|
||||
// debug settings, proxy configuration, and API keys.
|
||||
package config
|
||||
|
||||
// SDKConfig represents the application's configuration, loaded from a YAML file.
|
||||
type SDKConfig struct {
|
||||
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||
|
||||
// ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview")
|
||||
// to target prefixed credentials. When false, unprefixed model requests may use prefixed
|
||||
// credentials as well.
|
||||
ForceModelPrefix bool `yaml:"force-model-prefix" json:"force-model-prefix"`
|
||||
|
||||
// RequestLog enables or disables detailed request logging functionality.
|
||||
RequestLog bool `yaml:"request-log" json:"request-log"`
|
||||
|
||||
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
||||
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||
|
||||
// Access holds request authentication provider configuration.
|
||||
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"`
|
||||
|
||||
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
||||
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
||||
|
||||
// NonStreamKeepAliveInterval controls how often blank lines are emitted for non-streaming responses.
|
||||
// <= 0 disables keep-alives. Value is in seconds.
|
||||
NonStreamKeepAliveInterval int `yaml:"nonstream-keepalive-interval,omitempty" json:"nonstream-keepalive-interval,omitempty"`
|
||||
}
|
||||
|
||||
// StreamingConfig holds server streaming behavior configuration.
|
||||
type StreamingConfig struct {
|
||||
// KeepAliveSeconds controls how often the server emits SSE heartbeats (": keep-alive\n\n").
|
||||
// <= 0 disables keep-alives. Default is 0.
|
||||
KeepAliveSeconds int `yaml:"keepalive-seconds,omitempty" json:"keepalive-seconds,omitempty"`
|
||||
|
||||
// BootstrapRetries controls how many times the server may retry a streaming request before any bytes are sent,
|
||||
// to allow auth rotation / transient recovery.
|
||||
// <= 0 disables bootstrap retries. Default is 0.
|
||||
BootstrapRetries int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
|
||||
}
|
||||
|
||||
// AccessConfig groups request authentication providers.
|
||||
type AccessConfig struct {
|
||||
// Providers lists configured authentication providers.
|
||||
Providers []AccessProvider `yaml:"providers,omitempty" json:"providers,omitempty"`
|
||||
}
|
||||
|
||||
// AccessProvider describes a request authentication provider entry.
|
||||
type AccessProvider struct {
|
||||
// Name is the instance identifier for the provider.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Type selects the provider implementation registered via the SDK.
|
||||
Type string `yaml:"type" json:"type"`
|
||||
|
||||
// SDK optionally names a third-party SDK module providing this provider.
|
||||
SDK string `yaml:"sdk,omitempty" json:"sdk,omitempty"`
|
||||
|
||||
// APIKeys lists inline keys for providers that require them.
|
||||
APIKeys []string `yaml:"api-keys,omitempty" json:"api-keys,omitempty"`
|
||||
|
||||
// Config passes provider-specific options to the implementation.
|
||||
Config map[string]any `yaml:"config,omitempty" json:"config,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
// AccessProviderTypeConfigAPIKey is the built-in provider validating inline API keys.
|
||||
AccessProviderTypeConfigAPIKey = "config-api-key"
|
||||
|
||||
// DefaultAccessProviderName is applied when no provider name is supplied.
|
||||
DefaultAccessProviderName = "config-inline"
|
||||
)
|
||||
|
||||
// ConfigAPIKeyProvider returns the first inline API key provider if present.
|
||||
func (c *SDKConfig) ConfigAPIKeyProvider() *AccessProvider {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
for i := range c.Access.Providers {
|
||||
if c.Access.Providers[i].Type == AccessProviderTypeConfigAPIKey {
|
||||
if c.Access.Providers[i].Name == "" {
|
||||
c.Access.Providers[i].Name = DefaultAccessProviderName
|
||||
}
|
||||
return &c.Access.Providers[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MakeInlineAPIKeyProvider constructs an inline API key provider configuration.
|
||||
// It returns nil when no keys are supplied.
|
||||
func MakeInlineAPIKeyProvider(keys []string) *AccessProvider {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
provider := &AccessProvider{
|
||||
Name: DefaultAccessProviderName,
|
||||
Type: AccessProviderTypeConfigAPIKey,
|
||||
APIKeys: append([]string(nil), keys...),
|
||||
}
|
||||
return provider
|
||||
}
|
||||
@@ -13,6 +13,13 @@ type VertexCompatKey struct {
|
||||
// Maps to the x-goog-api-key header.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Priority controls selection preference when multiple credentials match.
|
||||
// Higher values are preferred; defaults to 0.
|
||||
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||
|
||||
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the Vertex-compatible API endpoint.
|
||||
// The executor will append "/v1/publishers/google/models/{model}:action" to this.
|
||||
// Example: "https://zenmux.ai/api" becomes "https://zenmux.ai/api/v1/publishers/google/models/..."
|
||||
@@ -29,6 +36,9 @@ type VertexCompatKey struct {
|
||||
Models []VertexCompatModel `yaml:"models,omitempty" json:"models,omitempty"`
|
||||
}
|
||||
|
||||
func (k VertexCompatKey) GetAPIKey() string { return k.APIKey }
|
||||
func (k VertexCompatKey) GetBaseURL() string { return k.BaseURL }
|
||||
|
||||
// VertexCompatModel represents a model configuration for Vertex compatibility,
|
||||
// including the actual model name and its alias for API routing.
|
||||
type VertexCompatModel struct {
|
||||
@@ -39,6 +49,9 @@ type VertexCompatModel struct {
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
}
|
||||
|
||||
func (m VertexCompatModel) GetName() string { return m.Name }
|
||||
func (m VertexCompatModel) GetAlias() string { return m.Alias }
|
||||
|
||||
// SanitizeVertexCompatKeys deduplicates and normalizes Vertex-compatible API key credentials.
|
||||
func (cfg *Config) SanitizeVertexCompatKeys() {
|
||||
if cfg == nil {
|
||||
@@ -53,6 +66,7 @@ func (cfg *Config) SanitizeVertexCompatKeys() {
|
||||
if entry.APIKey == "" {
|
||||
continue
|
||||
}
|
||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
if entry.BaseURL == "" {
|
||||
// BaseURL is required for Vertex API key entries
|
||||
|
||||
@@ -4,9 +4,11 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -14,11 +16,24 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// aiAPIPrefixes defines path prefixes for AI API requests that should have request ID tracking.
|
||||
var aiAPIPrefixes = []string{
|
||||
"/v1/chat/completions",
|
||||
"/v1/completions",
|
||||
"/v1/messages",
|
||||
"/v1/responses",
|
||||
"/v1beta/models/",
|
||||
"/api/provider/",
|
||||
}
|
||||
|
||||
const skipGinLogKey = "__gin_skip_request_logging__"
|
||||
|
||||
// GinLogrusLogger returns a Gin middleware handler that logs HTTP requests and responses
|
||||
// using logrus. It captures request details including method, path, status code, latency,
|
||||
// client IP, and any error messages, formatting them in a Gin-style log format.
|
||||
// client IP, and any error messages. Request ID is only added for AI API requests.
|
||||
//
|
||||
// Output format (AI API): [2025-12-23 20:14:10] [info ] | a1b2c3d4 | 200 | 23.559s | ...
|
||||
// Output format (others): [2025-12-23 20:14:10] [info ] | -------- | 200 | 23.559s | ...
|
||||
//
|
||||
// Returns:
|
||||
// - gin.HandlerFunc: A middleware handler for request logging
|
||||
@@ -28,6 +43,15 @@ func GinLogrusLogger() gin.HandlerFunc {
|
||||
path := c.Request.URL.Path
|
||||
raw := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
||||
|
||||
// Only generate request ID for AI API paths
|
||||
var requestID string
|
||||
if isAIAPIPath(path) {
|
||||
requestID = GenerateRequestID()
|
||||
SetGinRequestID(c, requestID)
|
||||
ctx := WithRequestID(c.Request.Context(), requestID)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
c.Next()
|
||||
|
||||
if shouldSkipGinRequestLogging(c) {
|
||||
@@ -49,23 +73,38 @@ func GinLogrusLogger() gin.HandlerFunc {
|
||||
clientIP := c.ClientIP()
|
||||
method := c.Request.Method
|
||||
errorMessage := c.Errors.ByType(gin.ErrorTypePrivate).String()
|
||||
timestamp := time.Now().Format("2006/01/02 - 15:04:05")
|
||||
logLine := fmt.Sprintf("[GIN] %s | %3d | %13v | %15s | %-7s \"%s\"", timestamp, statusCode, latency, clientIP, method, path)
|
||||
|
||||
if requestID == "" {
|
||||
requestID = "--------"
|
||||
}
|
||||
logLine := fmt.Sprintf("%3d | %13v | %15s | %-7s \"%s\"", statusCode, latency, clientIP, method, path)
|
||||
if errorMessage != "" {
|
||||
logLine = logLine + " | " + errorMessage
|
||||
}
|
||||
|
||||
entry := log.WithField("request_id", requestID)
|
||||
|
||||
switch {
|
||||
case statusCode >= http.StatusInternalServerError:
|
||||
log.Error(logLine)
|
||||
entry.Error(logLine)
|
||||
case statusCode >= http.StatusBadRequest:
|
||||
log.Warn(logLine)
|
||||
entry.Warn(logLine)
|
||||
default:
|
||||
log.Info(logLine)
|
||||
entry.Info(logLine)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isAIAPIPath checks if the given path is an AI API endpoint that should have request ID tracking.
|
||||
func isAIAPIPath(path string) bool {
|
||||
for _, prefix := range aiAPIPrefixes {
|
||||
if strings.HasPrefix(path, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GinLogrusRecovery returns a Gin middleware handler that recovers from panics and logs
|
||||
// them using logrus. When a panic occurs, it captures the panic value, stack trace,
|
||||
// and request path, then returns a 500 Internal Server Error response to the client.
|
||||
@@ -74,6 +113,11 @@ func GinLogrusLogger() gin.HandlerFunc {
|
||||
// - gin.HandlerFunc: A middleware handler for panic recovery
|
||||
func GinLogrusRecovery() gin.HandlerFunc {
|
||||
return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
|
||||
if err, ok := recovered.(error); ok && errors.Is(err, http.ErrAbortHandler) {
|
||||
// Let net/http handle ErrAbortHandler so the connection is aborted without noisy stack logs.
|
||||
panic(http.ErrAbortHandler)
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"panic": recovered,
|
||||
"stack": string(debug.Stack()),
|
||||
|
||||
60
internal/logging/gin_logger_test.go
Normal file
60
internal/logging/gin_logger_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestGinLogrusRecoveryRepanicsErrAbortHandler(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
engine := gin.New()
|
||||
engine.Use(GinLogrusRecovery())
|
||||
engine.GET("/abort", func(c *gin.Context) {
|
||||
panic(http.ErrAbortHandler)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/abort", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
defer func() {
|
||||
recovered := recover()
|
||||
if recovered == nil {
|
||||
t.Fatalf("expected panic, got nil")
|
||||
}
|
||||
err, ok := recovered.(error)
|
||||
if !ok {
|
||||
t.Fatalf("expected error panic, got %T", recovered)
|
||||
}
|
||||
if !errors.Is(err, http.ErrAbortHandler) {
|
||||
t.Fatalf("expected ErrAbortHandler, got %v", err)
|
||||
}
|
||||
if err != http.ErrAbortHandler {
|
||||
t.Fatalf("expected exact ErrAbortHandler sentinel, got %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
engine.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
func TestGinLogrusRecoveryHandlesRegularPanic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
engine := gin.New()
|
||||
engine.Use(GinLogrusRecovery())
|
||||
engine.GET("/panic", func(c *gin.Context) {
|
||||
panic("boom")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/panic", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
engine.ServeHTTP(recorder, req)
|
||||
if recorder.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected 500, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
@@ -24,9 +25,13 @@ var (
|
||||
)
|
||||
|
||||
// LogFormatter defines a custom log format for logrus.
|
||||
// This formatter adds timestamp, level, and source location to each log entry.
|
||||
// This formatter adds timestamp, level, request ID, and source location to each log entry.
|
||||
// Format: [2025-12-23 20:14:04] [debug] [manager.go:524] | a1b2c3d4 | Use API key sk-9...0RHO for model gpt-5.2
|
||||
type LogFormatter struct{}
|
||||
|
||||
// logFieldOrder defines the display order for common log fields.
|
||||
var logFieldOrder = []string{"provider", "model", "mode", "budget", "level", "original_mode", "original_value", "min", "max", "clamped_to", "error"}
|
||||
|
||||
// Format renders a single log entry with custom formatting.
|
||||
func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
|
||||
var buffer *bytes.Buffer
|
||||
@@ -38,16 +43,38 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
|
||||
|
||||
timestamp := entry.Time.Format("2006-01-02 15:04:05")
|
||||
message := strings.TrimRight(entry.Message, "\r\n")
|
||||
|
||||
// Handle nil Caller (can happen with some log entries)
|
||||
callerFile := "unknown"
|
||||
callerLine := 0
|
||||
if entry.Caller != nil {
|
||||
callerFile = filepath.Base(entry.Caller.File)
|
||||
callerLine = entry.Caller.Line
|
||||
|
||||
reqID := "--------"
|
||||
if id, ok := entry.Data["request_id"].(string); ok && id != "" {
|
||||
reqID = id
|
||||
}
|
||||
|
||||
level := entry.Level.String()
|
||||
if level == "warning" {
|
||||
level = "warn"
|
||||
}
|
||||
levelStr := fmt.Sprintf("%-5s", level)
|
||||
|
||||
// Build fields string (only print fields in logFieldOrder)
|
||||
var fieldsStr string
|
||||
if len(entry.Data) > 0 {
|
||||
var fields []string
|
||||
for _, k := range logFieldOrder {
|
||||
if v, ok := entry.Data[k]; ok {
|
||||
fields = append(fields, fmt.Sprintf("%s=%v", k, v))
|
||||
}
|
||||
}
|
||||
if len(fields) > 0 {
|
||||
fieldsStr = " " + strings.Join(fields, " ")
|
||||
}
|
||||
}
|
||||
|
||||
var formatted string
|
||||
if entry.Caller != nil {
|
||||
formatted = fmt.Sprintf("[%s] [%s] [%s] [%s:%d] %s%s\n", timestamp, reqID, levelStr, filepath.Base(entry.Caller.File), entry.Caller.Line, message, fieldsStr)
|
||||
} else {
|
||||
formatted = fmt.Sprintf("[%s] [%s] [%s] %s%s\n", timestamp, reqID, levelStr, message, fieldsStr)
|
||||
}
|
||||
|
||||
formatted := fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, callerFile, callerLine, message)
|
||||
buffer.WriteString(formatted)
|
||||
|
||||
return buffer.Bytes(), nil
|
||||
@@ -75,40 +102,81 @@ func SetupBaseLogger() {
|
||||
})
|
||||
}
|
||||
|
||||
// isDirWritable checks if the specified directory exists and is writable by attempting to create and remove a test file.
|
||||
func isDirWritable(dir string) bool {
|
||||
info, err := os.Stat(dir)
|
||||
if err != nil || !info.IsDir() {
|
||||
return false
|
||||
}
|
||||
|
||||
testFile := filepath.Join(dir, ".perm_test")
|
||||
f, err := os.Create(testFile)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
_ = os.Remove(testFile)
|
||||
}()
|
||||
return true
|
||||
}
|
||||
|
||||
// ResolveLogDirectory determines the directory used for application logs.
|
||||
func ResolveLogDirectory(cfg *config.Config) string {
|
||||
logDir := "logs"
|
||||
if base := util.WritablePath(); base != "" {
|
||||
return filepath.Join(base, "logs")
|
||||
}
|
||||
if cfg == nil {
|
||||
return logDir
|
||||
}
|
||||
if !isDirWritable(logDir) {
|
||||
authDir := strings.TrimSpace(cfg.AuthDir)
|
||||
if authDir != "" {
|
||||
logDir = filepath.Join(authDir, "logs")
|
||||
}
|
||||
}
|
||||
return logDir
|
||||
}
|
||||
|
||||
// ConfigureLogOutput switches the global log destination between rotating files and stdout.
|
||||
func ConfigureLogOutput(loggingToFile bool) error {
|
||||
// When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory
|
||||
// until the total size is within the limit.
|
||||
func ConfigureLogOutput(cfg *config.Config) error {
|
||||
SetupBaseLogger()
|
||||
|
||||
writerMu.Lock()
|
||||
defer writerMu.Unlock()
|
||||
|
||||
if loggingToFile {
|
||||
logDir := "logs"
|
||||
if base := util.WritablePath(); base != "" {
|
||||
logDir = filepath.Join(base, "logs")
|
||||
}
|
||||
logDir := ResolveLogDirectory(cfg)
|
||||
|
||||
protectedPath := ""
|
||||
if cfg.LoggingToFile {
|
||||
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
||||
return fmt.Errorf("logging: failed to create log directory: %w", err)
|
||||
}
|
||||
if logWriter != nil {
|
||||
_ = logWriter.Close()
|
||||
}
|
||||
protectedPath = filepath.Join(logDir, "main.log")
|
||||
logWriter = &lumberjack.Logger{
|
||||
Filename: filepath.Join(logDir, "main.log"),
|
||||
Filename: protectedPath,
|
||||
MaxSize: 10,
|
||||
MaxBackups: 0,
|
||||
MaxAge: 0,
|
||||
Compress: false,
|
||||
}
|
||||
log.SetOutput(logWriter)
|
||||
return nil
|
||||
} else {
|
||||
if logWriter != nil {
|
||||
_ = logWriter.Close()
|
||||
logWriter = nil
|
||||
}
|
||||
log.SetOutput(os.Stdout)
|
||||
}
|
||||
|
||||
if logWriter != nil {
|
||||
_ = logWriter.Close()
|
||||
logWriter = nil
|
||||
}
|
||||
log.SetOutput(os.Stdout)
|
||||
configureLogDirCleanerLocked(logDir, cfg.LogsMaxTotalSizeMB, protectedPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -116,6 +184,8 @@ func closeLogOutputs() {
|
||||
writerMu.Lock()
|
||||
defer writerMu.Unlock()
|
||||
|
||||
stopLogDirCleanerLocked()
|
||||
|
||||
if logWriter != nil {
|
||||
_ = logWriter.Close()
|
||||
logWriter = nil
|
||||
|
||||
166
internal/logging/log_dir_cleaner.go
Normal file
166
internal/logging/log_dir_cleaner.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const logDirCleanerInterval = time.Minute
|
||||
|
||||
var logDirCleanerCancel context.CancelFunc
|
||||
|
||||
func configureLogDirCleanerLocked(logDir string, maxTotalSizeMB int, protectedPath string) {
|
||||
stopLogDirCleanerLocked()
|
||||
|
||||
if maxTotalSizeMB <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
maxBytes := int64(maxTotalSizeMB) * 1024 * 1024
|
||||
if maxBytes <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
dir := strings.TrimSpace(logDir)
|
||||
if dir == "" {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
logDirCleanerCancel = cancel
|
||||
go runLogDirCleaner(ctx, filepath.Clean(dir), maxBytes, strings.TrimSpace(protectedPath))
|
||||
}
|
||||
|
||||
func stopLogDirCleanerLocked() {
|
||||
if logDirCleanerCancel == nil {
|
||||
return
|
||||
}
|
||||
logDirCleanerCancel()
|
||||
logDirCleanerCancel = nil
|
||||
}
|
||||
|
||||
func runLogDirCleaner(ctx context.Context, logDir string, maxBytes int64, protectedPath string) {
|
||||
ticker := time.NewTicker(logDirCleanerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
cleanOnce := func() {
|
||||
deleted, errClean := enforceLogDirSizeLimit(logDir, maxBytes, protectedPath)
|
||||
if errClean != nil {
|
||||
log.WithError(errClean).Warn("logging: failed to enforce log directory size limit")
|
||||
return
|
||||
}
|
||||
if deleted > 0 {
|
||||
log.Debugf("logging: removed %d old log file(s) to enforce log directory size limit", deleted)
|
||||
}
|
||||
}
|
||||
|
||||
cleanOnce()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
cleanOnce()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func enforceLogDirSizeLimit(logDir string, maxBytes int64, protectedPath string) (int, error) {
|
||||
if maxBytes <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
dir := strings.TrimSpace(logDir)
|
||||
if dir == "" {
|
||||
return 0, nil
|
||||
}
|
||||
dir = filepath.Clean(dir)
|
||||
|
||||
entries, errRead := os.ReadDir(dir)
|
||||
if errRead != nil {
|
||||
if os.IsNotExist(errRead) {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, errRead
|
||||
}
|
||||
|
||||
protected := strings.TrimSpace(protectedPath)
|
||||
if protected != "" {
|
||||
protected = filepath.Clean(protected)
|
||||
}
|
||||
|
||||
type logFile struct {
|
||||
path string
|
||||
size int64
|
||||
modTime time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
files []logFile
|
||||
total int64
|
||||
)
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if !isLogFileName(name) {
|
||||
continue
|
||||
}
|
||||
info, errInfo := entry.Info()
|
||||
if errInfo != nil {
|
||||
continue
|
||||
}
|
||||
if !info.Mode().IsRegular() {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(dir, name)
|
||||
files = append(files, logFile{
|
||||
path: path,
|
||||
size: info.Size(),
|
||||
modTime: info.ModTime(),
|
||||
})
|
||||
total += info.Size()
|
||||
}
|
||||
|
||||
if total <= maxBytes {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
return files[i].modTime.Before(files[j].modTime)
|
||||
})
|
||||
|
||||
deleted := 0
|
||||
for _, file := range files {
|
||||
if total <= maxBytes {
|
||||
break
|
||||
}
|
||||
if protected != "" && filepath.Clean(file.path) == protected {
|
||||
continue
|
||||
}
|
||||
if errRemove := os.Remove(file.path); errRemove != nil {
|
||||
log.WithError(errRemove).Warnf("logging: failed to remove old log file: %s", filepath.Base(file.path))
|
||||
continue
|
||||
}
|
||||
total -= file.size
|
||||
deleted++
|
||||
}
|
||||
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
func isLogFileName(name string) bool {
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
lower := strings.ToLower(trimmed)
|
||||
return strings.HasSuffix(lower, ".log") || strings.HasSuffix(lower, ".log.gz")
|
||||
}
|
||||
70
internal/logging/log_dir_cleaner_test.go
Normal file
70
internal/logging/log_dir_cleaner_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestEnforceLogDirSizeLimitDeletesOldest(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
writeLogFile(t, filepath.Join(dir, "old.log"), 60, time.Unix(1, 0))
|
||||
writeLogFile(t, filepath.Join(dir, "mid.log"), 60, time.Unix(2, 0))
|
||||
protected := filepath.Join(dir, "main.log")
|
||||
writeLogFile(t, protected, 60, time.Unix(3, 0))
|
||||
|
||||
deleted, err := enforceLogDirSizeLimit(dir, 120, protected)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if deleted != 1 {
|
||||
t.Fatalf("expected 1 deleted file, got %d", deleted)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(dir, "old.log")); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected old.log to be removed, stat error: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(dir, "mid.log")); err != nil {
|
||||
t.Fatalf("expected mid.log to remain, stat error: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(protected); err != nil {
|
||||
t.Fatalf("expected protected main.log to remain, stat error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceLogDirSizeLimitSkipsProtected(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
protected := filepath.Join(dir, "main.log")
|
||||
writeLogFile(t, protected, 200, time.Unix(1, 0))
|
||||
writeLogFile(t, filepath.Join(dir, "other.log"), 50, time.Unix(2, 0))
|
||||
|
||||
deleted, err := enforceLogDirSizeLimit(dir, 100, protected)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if deleted != 1 {
|
||||
t.Fatalf("expected 1 deleted file, got %d", deleted)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(protected); err != nil {
|
||||
t.Fatalf("expected protected main.log to remain, stat error: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(dir, "other.log")); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected other.log to be removed, stat error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func writeLogFile(t *testing.T, path string, size int, modTime time.Time) {
|
||||
t.Helper()
|
||||
|
||||
data := make([]byte, size)
|
||||
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||
t.Fatalf("write file: %v", err)
|
||||
}
|
||||
if err := os.Chtimes(path, modTime, modTime); err != nil {
|
||||
t.Fatalf("set times: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
@@ -25,6 +26,8 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
var requestLogID atomic.Uint64
|
||||
|
||||
// RequestLogger defines the interface for logging HTTP requests and responses.
|
||||
// It provides methods for logging both regular and streaming HTTP request/response cycles.
|
||||
type RequestLogger interface {
|
||||
@@ -40,10 +43,11 @@ type RequestLogger interface {
|
||||
// - response: The raw response data
|
||||
// - apiRequest: The API request data
|
||||
// - apiResponse: The API response data
|
||||
// - requestID: Optional request ID for log file naming
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if logging fails, nil otherwise
|
||||
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error
|
||||
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string) error
|
||||
|
||||
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
|
||||
//
|
||||
@@ -52,11 +56,12 @@ type RequestLogger interface {
|
||||
// - method: The HTTP method
|
||||
// - headers: The request headers
|
||||
// - body: The request body
|
||||
// - requestID: Optional request ID for log file naming
|
||||
//
|
||||
// Returns:
|
||||
// - StreamingLogWriter: A writer for streaming response chunks
|
||||
// - error: An error if logging initialization fails, nil otherwise
|
||||
LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error)
|
||||
LogStreamingRequest(url, method string, headers map[string][]string, body []byte, requestID string) (StreamingLogWriter, error)
|
||||
|
||||
// IsEnabled returns whether request logging is currently enabled.
|
||||
//
|
||||
@@ -174,20 +179,21 @@ func (l *FileRequestLogger) SetEnabled(enabled bool) {
|
||||
// - response: The raw response data
|
||||
// - apiRequest: The API request data
|
||||
// - apiResponse: The API response data
|
||||
// - requestID: Optional request ID for log file naming
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if logging fails, nil otherwise
|
||||
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false)
|
||||
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID)
|
||||
}
|
||||
|
||||
// LogRequestWithOptions logs a request with optional forced logging behavior.
|
||||
// The force flag allows writing error logs even when regular request logging is disabled.
|
||||
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force)
|
||||
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID)
|
||||
}
|
||||
|
||||
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error {
|
||||
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string) error {
|
||||
if !l.enabled && !force {
|
||||
return nil
|
||||
}
|
||||
@@ -197,26 +203,59 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
|
||||
return fmt.Errorf("failed to create logs directory: %w", errEnsure)
|
||||
}
|
||||
|
||||
// Generate filename
|
||||
filename := l.generateFilename(url)
|
||||
// Generate filename with request ID
|
||||
filename := l.generateFilename(url, requestID)
|
||||
if force && !l.enabled {
|
||||
filename = l.generateErrorFilename(url)
|
||||
filename = l.generateErrorFilename(url, requestID)
|
||||
}
|
||||
filePath := filepath.Join(l.logsDir, filename)
|
||||
|
||||
// Decompress response if needed
|
||||
decompressedResponse, err := l.decompressResponse(responseHeaders, response)
|
||||
if err != nil {
|
||||
// If decompression fails, log the error but continue with original response
|
||||
decompressedResponse = append(response, []byte(fmt.Sprintf("\n[DECOMPRESSION ERROR: %v]", err))...)
|
||||
requestBodyPath, errTemp := l.writeRequestBodyTempFile(body)
|
||||
if errTemp != nil {
|
||||
log.WithError(errTemp).Warn("failed to create request body temp file, falling back to direct write")
|
||||
}
|
||||
if requestBodyPath != "" {
|
||||
defer func() {
|
||||
if errRemove := os.Remove(requestBodyPath); errRemove != nil {
|
||||
log.WithError(errRemove).Warn("failed to remove request body temp file")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Create log content
|
||||
content := l.formatLogContent(url, method, requestHeaders, body, apiRequest, apiResponse, decompressedResponse, statusCode, responseHeaders, apiResponseErrors)
|
||||
responseToWrite, decompressErr := l.decompressResponse(responseHeaders, response)
|
||||
if decompressErr != nil {
|
||||
// If decompression fails, continue with original response and annotate the log output.
|
||||
responseToWrite = response
|
||||
}
|
||||
|
||||
// Write to file
|
||||
if err = os.WriteFile(filePath, []byte(content), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write log file: %w", err)
|
||||
logFile, errOpen := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
|
||||
if errOpen != nil {
|
||||
return fmt.Errorf("failed to create log file: %w", errOpen)
|
||||
}
|
||||
|
||||
writeErr := l.writeNonStreamingLog(
|
||||
logFile,
|
||||
url,
|
||||
method,
|
||||
requestHeaders,
|
||||
body,
|
||||
requestBodyPath,
|
||||
apiRequest,
|
||||
apiResponse,
|
||||
apiResponseErrors,
|
||||
statusCode,
|
||||
responseHeaders,
|
||||
responseToWrite,
|
||||
decompressErr,
|
||||
)
|
||||
if errClose := logFile.Close(); errClose != nil {
|
||||
log.WithError(errClose).Warn("failed to close request log file")
|
||||
if writeErr == nil {
|
||||
return errClose
|
||||
}
|
||||
}
|
||||
if writeErr != nil {
|
||||
return fmt.Errorf("failed to write log file: %w", writeErr)
|
||||
}
|
||||
|
||||
if force && !l.enabled {
|
||||
@@ -235,11 +274,12 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
|
||||
// - method: The HTTP method
|
||||
// - headers: The request headers
|
||||
// - body: The request body
|
||||
// - requestID: Optional request ID for log file naming
|
||||
//
|
||||
// Returns:
|
||||
// - StreamingLogWriter: A writer for streaming response chunks
|
||||
// - error: An error if logging initialization fails, nil otherwise
|
||||
func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error) {
|
||||
func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[string][]string, body []byte, requestID string) (StreamingLogWriter, error) {
|
||||
if !l.enabled {
|
||||
return &NoOpStreamingLogWriter{}, nil
|
||||
}
|
||||
@@ -249,30 +289,42 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
|
||||
return nil, fmt.Errorf("failed to create logs directory: %w", err)
|
||||
}
|
||||
|
||||
// Generate filename
|
||||
filename := l.generateFilename(url)
|
||||
// Generate filename with request ID
|
||||
filename := l.generateFilename(url, requestID)
|
||||
filePath := filepath.Join(l.logsDir, filename)
|
||||
|
||||
// Create and open file
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create log file: %w", err)
|
||||
requestHeaders := make(map[string][]string, len(headers))
|
||||
for key, values := range headers {
|
||||
headerValues := make([]string, len(values))
|
||||
copy(headerValues, values)
|
||||
requestHeaders[key] = headerValues
|
||||
}
|
||||
|
||||
// Write initial request information
|
||||
requestInfo := l.formatRequestInfo(url, method, headers, body)
|
||||
if _, err = file.WriteString(requestInfo); err != nil {
|
||||
_ = file.Close()
|
||||
return nil, fmt.Errorf("failed to write request info: %w", err)
|
||||
requestBodyPath, errTemp := l.writeRequestBodyTempFile(body)
|
||||
if errTemp != nil {
|
||||
return nil, fmt.Errorf("failed to create request body temp file: %w", errTemp)
|
||||
}
|
||||
|
||||
responseBodyFile, errCreate := os.CreateTemp(l.logsDir, "response-body-*.tmp")
|
||||
if errCreate != nil {
|
||||
_ = os.Remove(requestBodyPath)
|
||||
return nil, fmt.Errorf("failed to create response body temp file: %w", errCreate)
|
||||
}
|
||||
responseBodyPath := responseBodyFile.Name()
|
||||
|
||||
// Create streaming writer
|
||||
writer := &FileStreamingLogWriter{
|
||||
file: file,
|
||||
chunkChan: make(chan []byte, 100), // Buffered channel for async writes
|
||||
closeChan: make(chan struct{}),
|
||||
errorChan: make(chan error, 1),
|
||||
bufferedChunks: &bytes.Buffer{},
|
||||
logFilePath: filePath,
|
||||
url: url,
|
||||
method: method,
|
||||
timestamp: time.Now(),
|
||||
requestHeaders: requestHeaders,
|
||||
requestBodyPath: requestBodyPath,
|
||||
responseBodyPath: responseBodyPath,
|
||||
responseBodyFile: responseBodyFile,
|
||||
chunkChan: make(chan []byte, 100), // Buffered channel for async writes
|
||||
closeChan: make(chan struct{}),
|
||||
errorChan: make(chan error, 1),
|
||||
}
|
||||
|
||||
// Start async writer goroutine
|
||||
@@ -282,8 +334,8 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
|
||||
}
|
||||
|
||||
// generateErrorFilename creates a filename with an error prefix to differentiate forced error logs.
|
||||
func (l *FileRequestLogger) generateErrorFilename(url string) string {
|
||||
return fmt.Sprintf("error-%s", l.generateFilename(url))
|
||||
func (l *FileRequestLogger) generateErrorFilename(url string, requestID ...string) string {
|
||||
return fmt.Sprintf("error-%s", l.generateFilename(url, requestID...))
|
||||
}
|
||||
|
||||
// ensureLogsDir creates the logs directory if it doesn't exist.
|
||||
@@ -298,13 +350,15 @@ func (l *FileRequestLogger) ensureLogsDir() error {
|
||||
}
|
||||
|
||||
// generateFilename creates a sanitized filename from the URL path and current timestamp.
|
||||
// Format: v1-responses-2025-12-23T195811-a1b2c3d4.log
|
||||
//
|
||||
// Parameters:
|
||||
// - url: The request URL
|
||||
// - requestID: Optional request ID to include in filename
|
||||
//
|
||||
// Returns:
|
||||
// - string: A sanitized filename for the log file
|
||||
func (l *FileRequestLogger) generateFilename(url string) string {
|
||||
func (l *FileRequestLogger) generateFilename(url string, requestID ...string) string {
|
||||
// Extract path from URL
|
||||
path := url
|
||||
if strings.Contains(url, "?") {
|
||||
@@ -320,10 +374,18 @@ func (l *FileRequestLogger) generateFilename(url string) string {
|
||||
sanitized := l.sanitizeForFilename(path)
|
||||
|
||||
// Add timestamp
|
||||
timestamp := time.Now().Format("2006-01-02T150405-.000000000")
|
||||
timestamp = strings.Replace(timestamp, ".", "", -1)
|
||||
timestamp := time.Now().Format("2006-01-02T150405")
|
||||
|
||||
return fmt.Sprintf("%s-%s.log", sanitized, timestamp)
|
||||
// Use request ID if provided, otherwise use sequential ID
|
||||
var idPart string
|
||||
if len(requestID) > 0 && requestID[0] != "" {
|
||||
idPart = requestID[0]
|
||||
} else {
|
||||
id := requestLogID.Add(1)
|
||||
idPart = fmt.Sprintf("%d", id)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s-%s-%s.log", sanitized, timestamp, idPart)
|
||||
}
|
||||
|
||||
// sanitizeForFilename replaces characters that are not safe for filenames.
|
||||
@@ -405,6 +467,220 @@ func (l *FileRequestLogger) cleanupOldErrorLogs() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *FileRequestLogger) writeRequestBodyTempFile(body []byte) (string, error) {
|
||||
tmpFile, errCreate := os.CreateTemp(l.logsDir, "request-body-*.tmp")
|
||||
if errCreate != nil {
|
||||
return "", errCreate
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
|
||||
if _, errCopy := io.Copy(tmpFile, bytes.NewReader(body)); errCopy != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", errCopy
|
||||
}
|
||||
if errClose := tmpFile.Close(); errClose != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", errClose
|
||||
}
|
||||
return tmpPath, nil
|
||||
}
|
||||
|
||||
func (l *FileRequestLogger) writeNonStreamingLog(
|
||||
w io.Writer,
|
||||
url, method string,
|
||||
requestHeaders map[string][]string,
|
||||
requestBody []byte,
|
||||
requestBodyPath string,
|
||||
apiRequest []byte,
|
||||
apiResponse []byte,
|
||||
apiResponseErrors []*interfaces.ErrorMessage,
|
||||
statusCode int,
|
||||
responseHeaders map[string][]string,
|
||||
response []byte,
|
||||
decompressErr error,
|
||||
) error {
|
||||
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, time.Now()); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPIErrorResponses(w, apiResponseErrors); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true)
|
||||
}
|
||||
|
||||
func writeRequestInfoWithBody(
|
||||
w io.Writer,
|
||||
url, method string,
|
||||
headers map[string][]string,
|
||||
body []byte,
|
||||
bodyPath string,
|
||||
timestamp time.Time,
|
||||
) error {
|
||||
if _, errWrite := io.WriteString(w, "=== REQUEST INFO ===\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Version: %s\n", buildinfo.Version)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("URL: %s\n", url)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Method: %s\n", method)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "=== HEADERS ===\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
for key, values := range headers {
|
||||
for _, value := range values {
|
||||
masked := util.MaskSensitiveHeaderValue(key, value)
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("%s: %s\n", key, masked)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "=== REQUEST BODY ===\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
if bodyPath != "" {
|
||||
bodyFile, errOpen := os.Open(bodyPath)
|
||||
if errOpen != nil {
|
||||
return errOpen
|
||||
}
|
||||
if _, errCopy := io.Copy(w, bodyFile); errCopy != nil {
|
||||
_ = bodyFile.Close()
|
||||
return errCopy
|
||||
}
|
||||
if errClose := bodyFile.Close(); errClose != nil {
|
||||
log.WithError(errClose).Warn("failed to close request body temp file")
|
||||
}
|
||||
} else if _, errWrite := w.Write(body); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte) error {
|
||||
if len(payload) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(payload, []byte(sectionPrefix)) {
|
||||
if _, errWrite := w.Write(payload); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if !bytes.HasSuffix(payload, []byte("\n")) {
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := w.Write(payload); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeAPIErrorResponses(w io.Writer, apiResponseErrors []*interfaces.ErrorMessage) error {
|
||||
for i := 0; i < len(apiResponseErrors); i++ {
|
||||
if apiResponseErrors[i] == nil {
|
||||
continue
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "=== API ERROR RESPONSE ===\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("HTTP Status: %d\n", apiResponseErrors[i].StatusCode)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if apiResponseErrors[i].Error != nil {
|
||||
if _, errWrite := io.WriteString(w, apiResponseErrors[i].Error.Error()); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, responseHeaders map[string][]string, responseReader io.Reader, decompressErr error, trailingNewline bool) error {
|
||||
if _, errWrite := io.WriteString(w, "=== RESPONSE ===\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if statusWritten {
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Status: %d\n", statusCode)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
|
||||
if responseHeaders != nil {
|
||||
for key, values := range responseHeaders {
|
||||
for _, value := range values {
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("%s: %s\n", key, value)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
if responseReader != nil {
|
||||
if _, errCopy := io.Copy(w, responseReader); errCopy != nil {
|
||||
return errCopy
|
||||
}
|
||||
}
|
||||
if decompressErr != nil {
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("\n[DECOMPRESSION ERROR: %v]", decompressErr)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
|
||||
if trailingNewline {
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatLogContent creates the complete log content for non-streaming requests.
|
||||
//
|
||||
// Parameters:
|
||||
@@ -648,13 +924,34 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st
|
||||
}
|
||||
|
||||
// FileStreamingLogWriter implements StreamingLogWriter for file-based streaming logs.
|
||||
// It handles asynchronous writing of streaming response chunks to a file.
|
||||
// All data is buffered and written in the correct order when Close is called.
|
||||
// It spools streaming response chunks to a temporary file to avoid retaining large responses in memory.
|
||||
// The final log file is assembled when Close is called.
|
||||
type FileStreamingLogWriter struct {
|
||||
// file is the file where log data is written.
|
||||
file *os.File
|
||||
// logFilePath is the final log file path.
|
||||
logFilePath string
|
||||
|
||||
// chunkChan is a channel for receiving response chunks to buffer.
|
||||
// url is the request URL (masked upstream in middleware).
|
||||
url string
|
||||
|
||||
// method is the HTTP method.
|
||||
method string
|
||||
|
||||
// timestamp is captured when the streaming log is initialized.
|
||||
timestamp time.Time
|
||||
|
||||
// requestHeaders stores the request headers.
|
||||
requestHeaders map[string][]string
|
||||
|
||||
// requestBodyPath is a temporary file path holding the request body.
|
||||
requestBodyPath string
|
||||
|
||||
// responseBodyPath is a temporary file path holding the streaming response body.
|
||||
responseBodyPath string
|
||||
|
||||
// responseBodyFile is the temp file where chunks are appended by the async writer.
|
||||
responseBodyFile *os.File
|
||||
|
||||
// chunkChan is a channel for receiving response chunks to spool.
|
||||
chunkChan chan []byte
|
||||
|
||||
// closeChan is a channel for signaling when the writer is closed.
|
||||
@@ -663,9 +960,6 @@ type FileStreamingLogWriter struct {
|
||||
// errorChan is a channel for reporting errors during writing.
|
||||
errorChan chan error
|
||||
|
||||
// bufferedChunks stores the response chunks in order.
|
||||
bufferedChunks *bytes.Buffer
|
||||
|
||||
// responseStatus stores the HTTP status code.
|
||||
responseStatus int
|
||||
|
||||
@@ -770,85 +1064,115 @@ func (w *FileStreamingLogWriter) Close() error {
|
||||
close(w.chunkChan)
|
||||
}
|
||||
|
||||
// Wait for async writer to finish buffering chunks
|
||||
// Wait for async writer to finish spooling chunks
|
||||
if w.closeChan != nil {
|
||||
<-w.closeChan
|
||||
w.chunkChan = nil
|
||||
}
|
||||
|
||||
if w.file == nil {
|
||||
select {
|
||||
case errWrite := <-w.errorChan:
|
||||
w.cleanupTempFiles()
|
||||
return errWrite
|
||||
default:
|
||||
}
|
||||
|
||||
if w.logFilePath == "" {
|
||||
w.cleanupTempFiles()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write all content in the correct order
|
||||
var content strings.Builder
|
||||
|
||||
// 1. Write API REQUEST section
|
||||
if len(w.apiRequest) > 0 {
|
||||
if bytes.HasPrefix(w.apiRequest, []byte("=== API REQUEST")) {
|
||||
content.Write(w.apiRequest)
|
||||
if !bytes.HasSuffix(w.apiRequest, []byte("\n")) {
|
||||
content.WriteString("\n")
|
||||
}
|
||||
} else {
|
||||
content.WriteString("=== API REQUEST ===\n")
|
||||
content.Write(w.apiRequest)
|
||||
content.WriteString("\n")
|
||||
}
|
||||
content.WriteString("\n")
|
||||
logFile, errOpen := os.OpenFile(w.logFilePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
|
||||
if errOpen != nil {
|
||||
w.cleanupTempFiles()
|
||||
return fmt.Errorf("failed to create log file: %w", errOpen)
|
||||
}
|
||||
|
||||
// 2. Write API RESPONSE section
|
||||
if len(w.apiResponse) > 0 {
|
||||
if bytes.HasPrefix(w.apiResponse, []byte("=== API RESPONSE")) {
|
||||
content.Write(w.apiResponse)
|
||||
if !bytes.HasSuffix(w.apiResponse, []byte("\n")) {
|
||||
content.WriteString("\n")
|
||||
}
|
||||
} else {
|
||||
content.WriteString("=== API RESPONSE ===\n")
|
||||
content.Write(w.apiResponse)
|
||||
content.WriteString("\n")
|
||||
}
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
// 3. Write RESPONSE section (status, headers, buffered chunks)
|
||||
content.WriteString("=== RESPONSE ===\n")
|
||||
if w.statusWritten {
|
||||
content.WriteString(fmt.Sprintf("Status: %d\n", w.responseStatus))
|
||||
}
|
||||
|
||||
for key, values := range w.responseHeaders {
|
||||
for _, value := range values {
|
||||
content.WriteString(fmt.Sprintf("%s: %s\n", key, value))
|
||||
writeErr := w.writeFinalLog(logFile)
|
||||
if errClose := logFile.Close(); errClose != nil {
|
||||
log.WithError(errClose).Warn("failed to close request log file")
|
||||
if writeErr == nil {
|
||||
writeErr = errClose
|
||||
}
|
||||
}
|
||||
content.WriteString("\n")
|
||||
|
||||
// Write buffered response body chunks
|
||||
if w.bufferedChunks != nil && w.bufferedChunks.Len() > 0 {
|
||||
content.Write(w.bufferedChunks.Bytes())
|
||||
}
|
||||
|
||||
// Write the complete content to file
|
||||
if _, err := w.file.WriteString(content.String()); err != nil {
|
||||
_ = w.file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return w.file.Close()
|
||||
w.cleanupTempFiles()
|
||||
return writeErr
|
||||
}
|
||||
|
||||
// asyncWriter runs in a goroutine to buffer chunks from the channel.
|
||||
// It continuously reads chunks from the channel and buffers them for later writing.
|
||||
// It continuously reads chunks from the channel and appends them to a temp file for later assembly.
|
||||
func (w *FileStreamingLogWriter) asyncWriter() {
|
||||
defer close(w.closeChan)
|
||||
|
||||
for chunk := range w.chunkChan {
|
||||
if w.bufferedChunks != nil {
|
||||
w.bufferedChunks.Write(chunk)
|
||||
if w.responseBodyFile == nil {
|
||||
continue
|
||||
}
|
||||
if _, errWrite := w.responseBodyFile.Write(chunk); errWrite != nil {
|
||||
select {
|
||||
case w.errorChan <- errWrite:
|
||||
default:
|
||||
}
|
||||
if errClose := w.responseBodyFile.Close(); errClose != nil {
|
||||
select {
|
||||
case w.errorChan <- errClose:
|
||||
default:
|
||||
}
|
||||
}
|
||||
w.responseBodyFile = nil
|
||||
}
|
||||
}
|
||||
|
||||
if w.responseBodyFile == nil {
|
||||
return
|
||||
}
|
||||
if errClose := w.responseBodyFile.Close(); errClose != nil {
|
||||
select {
|
||||
case w.errorChan <- errClose:
|
||||
default:
|
||||
}
|
||||
}
|
||||
w.responseBodyFile = nil
|
||||
}
|
||||
|
||||
func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error {
|
||||
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(logFile, "=== API RESPONSE ===\n", "=== API RESPONSE", w.apiResponse); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
responseBodyFile, errOpen := os.Open(w.responseBodyPath)
|
||||
if errOpen != nil {
|
||||
return errOpen
|
||||
}
|
||||
defer func() {
|
||||
if errClose := responseBodyFile.Close(); errClose != nil {
|
||||
log.WithError(errClose).Warn("failed to close response body temp file")
|
||||
}
|
||||
}()
|
||||
|
||||
return writeResponseSection(logFile, w.responseStatus, w.statusWritten, w.responseHeaders, responseBodyFile, nil, false)
|
||||
}
|
||||
|
||||
func (w *FileStreamingLogWriter) cleanupTempFiles() {
|
||||
if w.requestBodyPath != "" {
|
||||
if errRemove := os.Remove(w.requestBodyPath); errRemove != nil {
|
||||
log.WithError(errRemove).Warn("failed to remove request body temp file")
|
||||
}
|
||||
w.requestBodyPath = ""
|
||||
}
|
||||
|
||||
if w.responseBodyPath != "" {
|
||||
if errRemove := os.Remove(w.responseBodyPath); errRemove != nil {
|
||||
log.WithError(errRemove).Warn("failed to remove response body temp file")
|
||||
}
|
||||
w.responseBodyPath = ""
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
61
internal/logging/requestid.go
Normal file
61
internal/logging/requestid.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// requestIDKey is the context key for storing/retrieving request IDs.
|
||||
type requestIDKey struct{}
|
||||
|
||||
// ginRequestIDKey is the Gin context key for request IDs.
|
||||
const ginRequestIDKey = "__request_id__"
|
||||
|
||||
// GenerateRequestID creates a new 8-character hex request ID.
|
||||
func GenerateRequestID() string {
|
||||
b := make([]byte, 4)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "00000000"
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// WithRequestID returns a new context with the request ID attached.
|
||||
func WithRequestID(ctx context.Context, requestID string) context.Context {
|
||||
return context.WithValue(ctx, requestIDKey{}, requestID)
|
||||
}
|
||||
|
||||
// GetRequestID retrieves the request ID from the context.
|
||||
// Returns empty string if not found.
|
||||
func GetRequestID(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
if id, ok := ctx.Value(requestIDKey{}).(string); ok {
|
||||
return id
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// SetGinRequestID stores the request ID in the Gin context.
|
||||
func SetGinRequestID(c *gin.Context, requestID string) {
|
||||
if c != nil {
|
||||
c.Set(ginRequestIDKey, requestID)
|
||||
}
|
||||
}
|
||||
|
||||
// GetGinRequestID retrieves the request ID from the Gin context.
|
||||
func GetGinRequestID(c *gin.Context) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
if id, exists := c.Get(ginRequestIDKey); exists {
|
||||
if s, ok := id.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -23,10 +24,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
managementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
|
||||
managementAssetName = "management.html"
|
||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||
updateCheckInterval = 3 * time.Hour
|
||||
defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
|
||||
defaultManagementFallbackURL = "https://cpamc.router-for.me/"
|
||||
managementAssetName = "management.html"
|
||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||
updateCheckInterval = 3 * time.Hour
|
||||
)
|
||||
|
||||
// ManagementFileName exposes the control panel asset filename.
|
||||
@@ -97,7 +99,7 @@ func runAutoUpdater(ctx context.Context) {
|
||||
|
||||
configPath, _ := schedulerConfigPath.Load().(string)
|
||||
staticDir := StaticDir(configPath)
|
||||
EnsureLatestManagementHTML(ctx, staticDir, cfg.ProxyURL)
|
||||
EnsureLatestManagementHTML(ctx, staticDir, cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
}
|
||||
|
||||
runOnce()
|
||||
@@ -181,7 +183,7 @@ func FilePath(configFilePath string) string {
|
||||
// EnsureLatestManagementHTML checks the latest management.html asset and updates the local copy when needed.
|
||||
// The function is designed to run in a background goroutine and will never panic.
|
||||
// It enforces a 3-hour rate limit to avoid frequent checks on config/auth file changes.
|
||||
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string) {
|
||||
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string, panelRepository string) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
@@ -197,6 +199,16 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
return
|
||||
}
|
||||
|
||||
localPath := filepath.Join(staticDir, managementAssetName)
|
||||
localFileMissing := false
|
||||
if _, errStat := os.Stat(localPath); errStat != nil {
|
||||
if errors.Is(errStat, os.ErrNotExist) {
|
||||
localFileMissing = true
|
||||
} else {
|
||||
log.WithError(errStat).Debug("failed to stat local management asset")
|
||||
}
|
||||
}
|
||||
|
||||
// Rate limiting: check only once every 3 hours
|
||||
lastUpdateCheckMu.Lock()
|
||||
now := time.Now()
|
||||
@@ -209,14 +221,14 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
lastUpdateCheckTime = now
|
||||
lastUpdateCheckMu.Unlock()
|
||||
|
||||
if err := os.MkdirAll(staticDir, 0o755); err != nil {
|
||||
log.WithError(err).Warn("failed to prepare static directory for management asset")
|
||||
if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
|
||||
log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
|
||||
return
|
||||
}
|
||||
|
||||
releaseURL := resolveReleaseURL(panelRepository)
|
||||
client := newHTTPClient(proxyURL)
|
||||
|
||||
localPath := filepath.Join(staticDir, managementAssetName)
|
||||
localHash, err := fileSHA256(localPath)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
@@ -225,8 +237,15 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
localHash = ""
|
||||
}
|
||||
|
||||
asset, remoteHash, err := fetchLatestAsset(ctx, client)
|
||||
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
log.WithError(err).Warn("failed to fetch latest management release information")
|
||||
return
|
||||
}
|
||||
@@ -238,6 +257,13 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
|
||||
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to download management asset, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
log.WithError(err).Warn("failed to download management asset")
|
||||
return
|
||||
}
|
||||
@@ -254,8 +280,60 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
|
||||
}
|
||||
|
||||
func fetchLatestAsset(ctx context.Context, client *http.Client) (*releaseAsset, string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, managementReleaseURL, nil)
|
||||
func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, localPath string) bool {
|
||||
data, downloadedHash, err := downloadAsset(ctx, client, defaultManagementFallbackURL)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("failed to download fallback management control panel page")
|
||||
return false
|
||||
}
|
||||
|
||||
if err = atomicWriteFile(localPath, data); err != nil {
|
||||
log.WithError(err).Warn("failed to persist fallback management control panel page")
|
||||
return false
|
||||
}
|
||||
|
||||
log.Infof("management asset updated from fallback page successfully (hash=%s)", downloadedHash)
|
||||
return true
|
||||
}
|
||||
|
||||
func resolveReleaseURL(repo string) string {
|
||||
repo = strings.TrimSpace(repo)
|
||||
if repo == "" {
|
||||
return defaultManagementReleaseURL
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(repo)
|
||||
if err != nil || parsed.Host == "" {
|
||||
return defaultManagementReleaseURL
|
||||
}
|
||||
|
||||
host := strings.ToLower(parsed.Host)
|
||||
parsed.Path = strings.TrimSuffix(parsed.Path, "/")
|
||||
|
||||
if host == "api.github.com" {
|
||||
if !strings.HasSuffix(strings.ToLower(parsed.Path), "/releases/latest") {
|
||||
parsed.Path = parsed.Path + "/releases/latest"
|
||||
}
|
||||
return parsed.String()
|
||||
}
|
||||
|
||||
if host == "github.com" {
|
||||
parts := strings.Split(strings.Trim(parsed.Path, "/"), "/")
|
||||
if len(parts) >= 2 && parts[0] != "" && parts[1] != "" {
|
||||
repoName := strings.TrimSuffix(parts[1], ".git")
|
||||
return fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", parts[0], repoName)
|
||||
}
|
||||
}
|
||||
|
||||
return defaultManagementReleaseURL
|
||||
}
|
||||
|
||||
func fetchLatestAsset(ctx context.Context, client *http.Client, releaseURL string) (*releaseAsset, string, error) {
|
||||
if strings.TrimSpace(releaseURL) == "" {
|
||||
releaseURL = defaultManagementReleaseURL
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, releaseURL, nil)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("create release request: %w", err)
|
||||
}
|
||||
|
||||
@@ -7,12 +7,93 @@ import (
|
||||
"embed"
|
||||
_ "embed"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// codexInstructionsEnabled controls whether CodexInstructionsForModel returns official instructions.
|
||||
// When false (default), CodexInstructionsForModel returns (true, "") immediately.
|
||||
// Set via SetCodexInstructionsEnabled from config.
|
||||
var codexInstructionsEnabled atomic.Bool
|
||||
|
||||
// SetCodexInstructionsEnabled sets whether codex instructions processing is enabled.
|
||||
func SetCodexInstructionsEnabled(enabled bool) {
|
||||
codexInstructionsEnabled.Store(enabled)
|
||||
}
|
||||
|
||||
// GetCodexInstructionsEnabled returns whether codex instructions processing is enabled.
|
||||
func GetCodexInstructionsEnabled() bool {
|
||||
return codexInstructionsEnabled.Load()
|
||||
}
|
||||
|
||||
//go:embed codex_instructions
|
||||
var codexInstructionsDir embed.FS
|
||||
|
||||
func CodexInstructionsForModel(modelName, systemInstructions string) (bool, string) {
|
||||
//go:embed opencode_codex_instructions.txt
|
||||
var opencodeCodexInstructions string
|
||||
|
||||
const (
|
||||
codexUserAgentKey = "__cpa_user_agent"
|
||||
userAgentOpenAISDK = "ai-sdk/openai/"
|
||||
)
|
||||
|
||||
func InjectCodexUserAgent(raw []byte, userAgent string) []byte {
|
||||
if len(raw) == 0 {
|
||||
return raw
|
||||
}
|
||||
trimmed := strings.TrimSpace(userAgent)
|
||||
if trimmed == "" {
|
||||
return raw
|
||||
}
|
||||
updated, err := sjson.SetBytes(raw, codexUserAgentKey, trimmed)
|
||||
if err != nil {
|
||||
return raw
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
func ExtractCodexUserAgent(raw []byte) string {
|
||||
if len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(gjson.GetBytes(raw, codexUserAgentKey).String())
|
||||
}
|
||||
|
||||
func StripCodexUserAgent(raw []byte) []byte {
|
||||
if len(raw) == 0 {
|
||||
return raw
|
||||
}
|
||||
if !gjson.GetBytes(raw, codexUserAgentKey).Exists() {
|
||||
return raw
|
||||
}
|
||||
updated, err := sjson.DeleteBytes(raw, codexUserAgentKey)
|
||||
if err != nil {
|
||||
return raw
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
func codexInstructionsForOpenCode(systemInstructions string) (bool, string) {
|
||||
if opencodeCodexInstructions == "" {
|
||||
return false, ""
|
||||
}
|
||||
if strings.HasPrefix(systemInstructions, opencodeCodexInstructions) {
|
||||
return true, ""
|
||||
}
|
||||
return false, opencodeCodexInstructions
|
||||
}
|
||||
|
||||
func useOpenCodeInstructions(userAgent string) bool {
|
||||
return strings.Contains(strings.ToLower(userAgent), userAgentOpenAISDK)
|
||||
}
|
||||
|
||||
func IsOpenCodeUserAgent(userAgent string) bool {
|
||||
return useOpenCodeInstructions(userAgent)
|
||||
}
|
||||
|
||||
func codexInstructionsForCodex(modelName, systemInstructions string) (bool, string) {
|
||||
entries, _ := codexInstructionsDir.ReadDir("codex_instructions")
|
||||
|
||||
lastPrompt := ""
|
||||
@@ -20,6 +101,7 @@ func CodexInstructionsForModel(modelName, systemInstructions string) (bool, stri
|
||||
lastCodexMaxPrompt := ""
|
||||
last51Prompt := ""
|
||||
last52Prompt := ""
|
||||
last52CodexPrompt := ""
|
||||
// lastReviewPrompt := ""
|
||||
for _, entry := range entries {
|
||||
content, _ := codexInstructionsDir.ReadFile("codex_instructions/" + entry.Name())
|
||||
@@ -36,12 +118,16 @@ func CodexInstructionsForModel(modelName, systemInstructions string) (bool, stri
|
||||
last51Prompt = string(content)
|
||||
} else if strings.HasPrefix(entry.Name(), "gpt_5_2_prompt.md") {
|
||||
last52Prompt = string(content)
|
||||
} else if strings.HasPrefix(entry.Name(), "gpt-5.2-codex_prompt.md") {
|
||||
last52CodexPrompt = string(content)
|
||||
} else if strings.HasPrefix(entry.Name(), "review_prompt.md") {
|
||||
// lastReviewPrompt = string(content)
|
||||
}
|
||||
}
|
||||
if strings.Contains(modelName, "codex-max") {
|
||||
return false, lastCodexMaxPrompt
|
||||
} else if strings.Contains(modelName, "5.2-codex") {
|
||||
return false, last52CodexPrompt
|
||||
} else if strings.Contains(modelName, "codex") {
|
||||
return false, lastCodexPrompt
|
||||
} else if strings.Contains(modelName, "5.1") {
|
||||
@@ -52,3 +138,13 @@ func CodexInstructionsForModel(modelName, systemInstructions string) (bool, stri
|
||||
return false, lastPrompt
|
||||
}
|
||||
}
|
||||
|
||||
func CodexInstructionsForModel(modelName, systemInstructions, userAgent string) (bool, string) {
|
||||
if !GetCodexInstructionsEnabled() {
|
||||
return true, ""
|
||||
}
|
||||
if IsOpenCodeUserAgent(userAgent) {
|
||||
return codexInstructionsForOpenCode(systemInstructions)
|
||||
}
|
||||
return codexInstructionsForCodex(modelName, systemInstructions)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer.
|
||||
|
||||
## General
|
||||
|
||||
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
|
||||
|
||||
## Editing constraints
|
||||
|
||||
- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
|
||||
- Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like "Assigns the value to the variable", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare.
|
||||
- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
|
||||
- You may be in a dirty git worktree.
|
||||
* NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
|
||||
* If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
|
||||
* If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
|
||||
* If the changes are in unrelated files, just ignore them and don't revert them.
|
||||
- Do not amend a commit unless explicitly requested to do so.
|
||||
- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed.
|
||||
- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.
|
||||
|
||||
## Plan tool
|
||||
|
||||
When using the planning tool:
|
||||
- Skip using the planning tool for straightforward tasks (roughly the easiest 25%).
|
||||
- Do not make single-step plans.
|
||||
- When you made a plan, update it after having performed one of the sub-tasks that you shared on the plan.
|
||||
|
||||
## Codex CLI harness, sandboxing, and approvals
|
||||
|
||||
The Codex CLI harness supports several different configurations for sandboxing and escalation approvals that the user can choose from.
|
||||
|
||||
Filesystem sandboxing defines which files can be read or written. The options for `sandbox_mode` are:
|
||||
- **read-only**: The sandbox only permits reading files.
|
||||
- **workspace-write**: The sandbox permits reading files, and editing files in `cwd` and `writable_roots`. Editing files in other directories requires approval.
|
||||
- **danger-full-access**: No filesystem sandboxing - all commands are permitted.
|
||||
|
||||
Network sandboxing defines whether network can be accessed without approval. Options for `network_access` are:
|
||||
- **restricted**: Requires approval
|
||||
- **enabled**: No approval needed
|
||||
|
||||
Approvals are your mechanism to get user consent to run shell commands without the sandbox. Possible configuration options for `approval_policy` are
|
||||
- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe "read" commands.
|
||||
- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.
|
||||
- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)
|
||||
- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is paired with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.
|
||||
|
||||
When you are running with `approval_policy == on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:
|
||||
- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /var)
|
||||
- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.
|
||||
- You are running sandboxed and need to run a command that requires network access (e.g. installing packages)
|
||||
- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval. ALWAYS proceed to use the `sandbox_permissions` and `justification` parameters - do not message the user before requesting approval for the command.
|
||||
- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for
|
||||
- (for all of these, you should weigh alternative paths that do not require approval)
|
||||
|
||||
When `sandbox_mode` is set to read-only, you'll need to request approval for any command that isn't a read.
|
||||
|
||||
You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing enabled, and approval on-failure.
|
||||
|
||||
Although they introduce friction to the user because your work is paused until the user responds, you should leverage them when necessary to accomplish important work. If the completing the task requires escalated permissions, Do not let these settings or the sandbox deter you from attempting to accomplish the user's task unless it is set to "never", in which case never ask for approvals.
|
||||
|
||||
When requesting approval to execute a command that will require escalated privileges:
|
||||
- Provide the `sandbox_permissions` parameter with the value `"require_escalated"`
|
||||
- Include a short, 1 sentence explanation for why you need escalated permissions in the justification parameter
|
||||
|
||||
## Special user requests
|
||||
|
||||
- If the user makes a simple request (such as asking for the time) which you can fulfill by running a terminal command (such as `date`), you should do so.
|
||||
- If the user asks for a "review", default to a code review mindset: prioritise identifying bugs, risks, behavioural regressions, and missing tests. Findings must be the primary focus of the response - keep summaries or overviews brief and only after enumerating the issues. Present findings first (ordered by severity with file/line references), follow with open questions or assumptions, and offer a change-summary only as a secondary detail. If no findings are discovered, state that explicitly and mention any residual risks or testing gaps.
|
||||
|
||||
## Frontend tasks
|
||||
When doing frontend design tasks, avoid collapsing into "AI slop" or safe, average-looking layouts.
|
||||
Aim for interfaces that feel intentional, bold, and a bit surprising.
|
||||
- Typography: Use expressive, purposeful fonts and avoid default stacks (Inter, Roboto, Arial, system).
|
||||
- Color & Look: Choose a clear visual direction; define CSS variables; avoid purple-on-white defaults. No purple bias or dark mode bias.
|
||||
- Motion: Use a few meaningful animations (page-load, staggered reveals) instead of generic micro-motions.
|
||||
- Background: Don't rely on flat, single-color backgrounds; use gradients, shapes, or subtle patterns to build atmosphere.
|
||||
- Overall: Avoid boilerplate layouts and interchangeable UI patterns. Vary themes, type families, and visual languages across outputs.
|
||||
- Ensure the page loads properly on both desktop and mobile
|
||||
|
||||
Exception: If working within an existing website or design system, preserve the established patterns, structure, and visual language.
|
||||
|
||||
## Presenting your work and final message
|
||||
|
||||
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
|
||||
|
||||
- Default: be very concise; friendly coding teammate tone.
|
||||
- Ask only when needed; suggest ideas; mirror the user's style.
|
||||
- For substantial work, summarize clearly; follow final‑answer formatting.
|
||||
- Skip heavy formatting for simple confirmations.
|
||||
- Don't dump large files you've written; reference paths only.
|
||||
- No "save/copy this file" - User is on the same machine.
|
||||
- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something.
|
||||
- For code changes:
|
||||
* Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with "summary", just jump right in.
|
||||
* If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps.
|
||||
* When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number.
|
||||
- The user does not command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result.
|
||||
|
||||
### Final answer structure and style guidelines
|
||||
|
||||
- Plain text; CLI handles styling. Use structure only when it helps scanability.
|
||||
- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help.
|
||||
- Bullets: use - ; merge related points; keep to one line when possible; 4–6 per list ordered by importance; keep phrasing consistent.
|
||||
- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **.
|
||||
- Code samples or multi-line snippets should be wrapped in fenced code blocks; include an info string as often as possible.
|
||||
- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task.
|
||||
- Tone: collaborative, concise, factual; present tense, active voice; self‑contained; no "above/below"; parallel wording.
|
||||
- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers.
|
||||
- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets.
|
||||
- File References: When referencing files in your response follow the below rules:
|
||||
* Use inline code to make file paths clickable.
|
||||
* Each reference should have a stand alone path. Even if it's the same file.
|
||||
* Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix.
|
||||
* Optionally include line/column (1‑based): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
|
||||
* Do not use URIs like file://, vscode://, or https://.
|
||||
* Do not provide range of lines
|
||||
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\repo\project\main.rs:12:5
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GenerateRandomState generates a cryptographically secure random state parameter
|
||||
@@ -19,3 +21,83 @@ func GenerateRandomState() (string, error) {
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// OAuthCallback captures the parsed OAuth callback parameters.
|
||||
type OAuthCallback struct {
|
||||
Code string
|
||||
State string
|
||||
Error string
|
||||
ErrorDescription string
|
||||
}
|
||||
|
||||
// ParseOAuthCallback extracts OAuth parameters from a callback URL.
|
||||
// It returns nil when the input is empty.
|
||||
func ParseOAuthCallback(input string) (*OAuthCallback, error) {
|
||||
trimmed := strings.TrimSpace(input)
|
||||
if trimmed == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
candidate := trimmed
|
||||
if !strings.Contains(candidate, "://") {
|
||||
if strings.HasPrefix(candidate, "?") {
|
||||
candidate = "http://localhost" + candidate
|
||||
} else if strings.ContainsAny(candidate, "/?#") || strings.Contains(candidate, ":") {
|
||||
candidate = "http://" + candidate
|
||||
} else if strings.Contains(candidate, "=") {
|
||||
candidate = "http://localhost/?" + candidate
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid callback URL")
|
||||
}
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(candidate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
code := strings.TrimSpace(query.Get("code"))
|
||||
state := strings.TrimSpace(query.Get("state"))
|
||||
errCode := strings.TrimSpace(query.Get("error"))
|
||||
errDesc := strings.TrimSpace(query.Get("error_description"))
|
||||
|
||||
if parsedURL.Fragment != "" {
|
||||
if fragQuery, errFrag := url.ParseQuery(parsedURL.Fragment); errFrag == nil {
|
||||
if code == "" {
|
||||
code = strings.TrimSpace(fragQuery.Get("code"))
|
||||
}
|
||||
if state == "" {
|
||||
state = strings.TrimSpace(fragQuery.Get("state"))
|
||||
}
|
||||
if errCode == "" {
|
||||
errCode = strings.TrimSpace(fragQuery.Get("error"))
|
||||
}
|
||||
if errDesc == "" {
|
||||
errDesc = strings.TrimSpace(fragQuery.Get("error_description"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if code != "" && state == "" && strings.Contains(code, "#") {
|
||||
parts := strings.SplitN(code, "#", 2)
|
||||
code = parts[0]
|
||||
state = parts[1]
|
||||
}
|
||||
|
||||
if errCode == "" && errDesc != "" {
|
||||
errCode = errDesc
|
||||
errDesc = ""
|
||||
}
|
||||
|
||||
if code == "" && errCode == "" {
|
||||
return nil, fmt.Errorf("callback URL missing code")
|
||||
}
|
||||
|
||||
return &OAuthCallback{
|
||||
Code: code,
|
||||
State: state,
|
||||
Error: errCode,
|
||||
ErrorDescription: errDesc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
318
internal/misc/opencode_codex_instructions.txt
Normal file
318
internal/misc/opencode_codex_instructions.txt
Normal file
@@ -0,0 +1,318 @@
|
||||
You are a coding agent running in the opencode, a terminal-based coding assistant. opencode is an open source project. You are expected to be precise, safe, and helpful.
|
||||
|
||||
Your capabilities:
|
||||
|
||||
- Receive user prompts and other context provided by the harness, such as files in the workspace.
|
||||
- Communicate with the user by streaming thinking & responses, and by making & updating plans.
|
||||
- Emit function calls to run terminal commands and apply edits. Depending on how this specific run is configured, you can request that these function calls be escalated to the user for approval before running. More on this in the "Sandbox and approvals" section.
|
||||
|
||||
Within this context, Codex refers to the open-source agentic coding interface (not the old Codex language model built by OpenAI).
|
||||
|
||||
# How you work
|
||||
|
||||
## Personality
|
||||
|
||||
Your default personality and tone is concise, direct, and friendly. You communicate efficiently, always keeping the user clearly informed about ongoing actions without unnecessary detail. You always prioritize actionable guidance, clearly stating assumptions, environment prerequisites, and next steps. Unless explicitly asked, you avoid excessively verbose explanations about your work.
|
||||
|
||||
# AGENTS.md spec
|
||||
- Repos often contain AGENTS.md files. These files can appear anywhere within the repository.
|
||||
- These files are a way for humans to give you (the agent) instructions or tips for working within the container.
|
||||
- Some examples might be: coding conventions, info about how code is organized, or instructions for how to run or test code.
|
||||
- Instructions in AGENTS.md files:
|
||||
- The scope of an AGENTS.md file is the entire directory tree rooted at the folder that contains it.
|
||||
- For every file you touch in the final patch, you must obey instructions in any AGENTS.md file whose scope includes that file.
|
||||
- Instructions about code style, structure, naming, etc. apply only to code within the AGENTS.md file's scope, unless the file states otherwise.
|
||||
- More-deeply-nested AGENTS.md files take precedence in the case of conflicting instructions.
|
||||
- Direct system/developer/user instructions (as part of a prompt) take precedence over AGENTS.md instructions.
|
||||
- The contents of the AGENTS.md file at the root of the repo and any directories from the CWD up to the root are included with the developer message and don't need to be re-read. When working in a subdirectory of CWD, or a directory outside the CWD, check for any AGENTS.md files that may be applicable.
|
||||
|
||||
## Responsiveness
|
||||
|
||||
### Preamble messages
|
||||
|
||||
Before making tool calls, send a brief preamble to the user explaining what you’re about to do. When sending preamble messages, follow these principles and examples:
|
||||
|
||||
- **Logically group related actions**: if you’re about to run several related commands, describe them together in one preamble rather than sending a separate note for each.
|
||||
- **Keep it concise**: be no more than 1-2 sentences, focused on immediate, tangible next steps. (8–12 words for quick updates).
|
||||
- **Build on prior context**: if this is not your first tool call, use the preamble message to connect the dots with what’s been done so far and create a sense of momentum and clarity for the user to understand your next actions.
|
||||
- **Keep your tone light, friendly and curious**: add small touches of personality in preambles feel collaborative and engaging.
|
||||
- **Exception**: Avoid adding a preamble for every trivial read (e.g., `cat` a single file) unless it’s part of a larger grouped action.
|
||||
|
||||
**Examples:**
|
||||
|
||||
- “I’ve explored the repo; now checking the API route definitions.”
|
||||
- “Next, I’ll patch the config and update the related tests.”
|
||||
- “I’m about to scaffold the CLI commands and helper functions.”
|
||||
- “Ok cool, so I’ve wrapped my head around the repo. Now digging into the API routes.”
|
||||
- “Config’s looking tidy. Next up is editing helpers to keep things in sync.”
|
||||
- “Finished poking at the DB gateway. I will now chase down error handling.”
|
||||
- “Alright, build pipeline order is interesting. Checking how it reports failures.”
|
||||
- “Spotted a clever caching util; now hunting where it gets used.”
|
||||
|
||||
## Planning
|
||||
|
||||
You have access to an `todowrite` tool which tracks steps and progress and renders them to the user. Using the tool helps demonstrate that you've understood the task and convey how you're approaching it. Plans can help to make complex, ambiguous, or multi-phase work clearer and more collaborative for the user. A good plan should break the task into meaningful, logically ordered steps that are easy to verify as you go.
|
||||
|
||||
Note that plans are not for padding out simple work with filler steps or stating the obvious. The content of your plan should not involve doing anything that you aren't capable of doing (i.e. don't try to test things that you can't test). Do not use plans for simple or single-step queries that you can just do or answer immediately.
|
||||
|
||||
Do not repeat the full contents of the plan after an `todowrite` call — the harness already displays it. Instead, summarize the change made and highlight any important context or next step.
|
||||
|
||||
Before running a command, consider whether or not you have completed the
|
||||
previous step, and make sure to mark it as completed before moving on to the
|
||||
next step. It may be the case that you complete all steps in your plan after a
|
||||
single pass of implementation. If this is the case, you can simply mark all the
|
||||
planned steps as completed. Sometimes, you may need to change plans in the
|
||||
middle of a task: call `todowrite` with the updated plan and make sure to provide an `explanation` of the rationale when doing so.
|
||||
|
||||
Use a plan when:
|
||||
|
||||
- The task is non-trivial and will require multiple actions over a long time horizon.
|
||||
- There are logical phases or dependencies where sequencing matters.
|
||||
- The work has ambiguity that benefits from outlining high-level goals.
|
||||
- You want intermediate checkpoints for feedback and validation.
|
||||
- When the user asked you to do more than one thing in a single prompt
|
||||
- The user has asked you to use the plan tool (aka "TODOs")
|
||||
- You generate additional steps while working, and plan to do them before yielding to the user
|
||||
|
||||
### Examples
|
||||
|
||||
**High-quality plans**
|
||||
|
||||
Example 1:
|
||||
|
||||
1. Add CLI entry with file args
|
||||
2. Parse Markdown via CommonMark library
|
||||
3. Apply semantic HTML template
|
||||
4. Handle code blocks, images, links
|
||||
5. Add error handling for invalid files
|
||||
|
||||
Example 2:
|
||||
|
||||
1. Define CSS variables for colors
|
||||
2. Add toggle with localStorage state
|
||||
3. Refactor components to use variables
|
||||
4. Verify all views for readability
|
||||
5. Add smooth theme-change transition
|
||||
|
||||
Example 3:
|
||||
|
||||
1. Set up Node.js + WebSocket server
|
||||
2. Add join/leave broadcast events
|
||||
3. Implement messaging with timestamps
|
||||
4. Add usernames + mention highlighting
|
||||
5. Persist messages in lightweight DB
|
||||
6. Add typing indicators + unread count
|
||||
|
||||
**Low-quality plans**
|
||||
|
||||
Example 1:
|
||||
|
||||
1. Create CLI tool
|
||||
2. Add Markdown parser
|
||||
3. Convert to HTML
|
||||
|
||||
Example 2:
|
||||
|
||||
1. Add dark mode toggle
|
||||
2. Save preference
|
||||
3. Make styles look good
|
||||
|
||||
Example 3:
|
||||
|
||||
1. Create single-file HTML game
|
||||
2. Run quick sanity check
|
||||
3. Summarize usage instructions
|
||||
|
||||
If you need to write a plan, only write high quality plans, not low quality ones.
|
||||
|
||||
## Task execution
|
||||
|
||||
You are a coding agent. Please keep going until the query is completely resolved, before ending your turn and yielding back to the user. Only terminate your turn when you are sure that the problem is solved. Autonomously resolve the query to the best of your ability, using the tools available to you, before coming back to the user. Do NOT guess or make up an answer.
|
||||
|
||||
You MUST adhere to the following criteria when solving queries:
|
||||
|
||||
- Working on the repo(s) in the current environment is allowed, even if they are proprietary.
|
||||
- Analyzing code for vulnerabilities is allowed.
|
||||
- Showing user code and tool call details is allowed.
|
||||
- Use the `edit` tool to edit files
|
||||
|
||||
If completing the user's task requires writing or modifying files, your code and final answer should follow these coding guidelines, though user instructions (i.e. AGENTS.md) may override these guidelines:
|
||||
|
||||
- Fix the problem at the root cause rather than applying surface-level patches, when possible.
|
||||
- Avoid unneeded complexity in your solution.
|
||||
- Do not attempt to fix unrelated bugs or broken tests. It is not your responsibility to fix them. (You may mention them to the user in your final message though.)
|
||||
- Update documentation as necessary.
|
||||
- Keep changes consistent with the style of the existing codebase. Changes should be minimal and focused on the task.
|
||||
- Use `git log` and `git blame` to search the history of the codebase if additional context is required.
|
||||
- NEVER add copyright or license headers unless specifically requested.
|
||||
- Do not waste tokens by re-reading files after calling `edit` on them. The tool call will fail if it didn't work. The same goes for making folders, deleting folders, etc.
|
||||
- Do not `git commit` your changes or create new git branches unless explicitly requested.
|
||||
- Do not add inline comments within code unless explicitly requested.
|
||||
- Do not use one-letter variable names unless explicitly requested.
|
||||
- NEVER output inline citations like "【F:README.md†L5-L14】" in your outputs. The CLI is not able to render these so they will just be broken in the UI. Instead, if you output valid filepaths, users will be able to click on them to open the files in their editor.
|
||||
|
||||
## Sandbox and approvals
|
||||
|
||||
The Codex CLI harness supports several different sandboxing, and approval configurations that the user can choose from.
|
||||
|
||||
Filesystem sandboxing prevents you from editing files without user approval. The options are:
|
||||
|
||||
- **read-only**: You can only read files.
|
||||
- **workspace-write**: You can read files. You can write to files in your workspace folder, but not outside it.
|
||||
- **danger-full-access**: No filesystem sandboxing.
|
||||
|
||||
Network sandboxing prevents you from accessing network without approval. Options are
|
||||
|
||||
- **restricted**
|
||||
- **enabled**
|
||||
|
||||
Approvals are your mechanism to get user consent to perform more privileged actions. Although they introduce friction to the user because your work is paused until the user responds, you should leverage them to accomplish your important work. Do not let these settings or the sandbox deter you from attempting to accomplish the user's task. Approval options are
|
||||
|
||||
- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe "read" commands.
|
||||
- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.
|
||||
- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)
|
||||
- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is pared with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.
|
||||
|
||||
When you are running with approvals `on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:
|
||||
|
||||
- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /tmp)
|
||||
- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.
|
||||
- You are running sandboxed and need to run a command that requires network access (e.g. installing packages)
|
||||
- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval.
|
||||
- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for
|
||||
- (For all of these, you should weigh alternative paths that do not require approval.)
|
||||
|
||||
Note that when sandboxing is set to read-only, you'll need to request approval for any command that isn't a read.
|
||||
|
||||
You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing ON, and approval on-failure.
|
||||
|
||||
## Validating your work
|
||||
|
||||
If the codebase has tests or the ability to build or run, consider using them to verify that your work is complete.
|
||||
|
||||
When testing, your philosophy should be to start as specific as possible to the code you changed so that you can catch issues efficiently, then make your way to broader tests as you build confidence. If there's no test for the code you changed, and if the adjacent patterns in the codebases show that there's a logical place for you to add a test, you may do so. However, do not add tests to codebases with no tests.
|
||||
|
||||
Similarly, once you're confident in correctness, you can suggest or use formatting commands to ensure that your code is well formatted. If there are issues you can iterate up to 3 times to get formatting right, but if you still can't manage it's better to save the user time and present them a correct solution where you call out the formatting in your final message. If the codebase does not have a formatter configured, do not add one.
|
||||
|
||||
For all of testing, running, building, and formatting, do not attempt to fix unrelated bugs. It is not your responsibility to fix them. (You may mention them to the user in your final message though.)
|
||||
|
||||
Be mindful of whether to run validation commands proactively. In the absence of behavioral guidance:
|
||||
|
||||
- When running in non-interactive approval modes like **never** or **on-failure**, proactively run tests, lint and do whatever you need to ensure you've completed the task.
|
||||
- When working in interactive approval modes like **untrusted**, or **on-request**, hold off on running tests or lint commands until the user is ready for you to finalize your output, because these commands take time to run and slow down iteration. Instead suggest what you want to do next, and let the user confirm first.
|
||||
- When working on test-related tasks, such as adding tests, fixing tests, or reproducing a bug to verify behavior, you may proactively run tests regardless of approval mode. Use your judgement to decide whether this is a test-related task.
|
||||
|
||||
## Ambition vs. precision
|
||||
|
||||
For tasks that have no prior context (i.e. the user is starting something brand new), you should feel free to be ambitious and demonstrate creativity with your implementation.
|
||||
|
||||
If you're operating in an existing codebase, you should make sure you do exactly what the user asks with surgical precision. Treat the surrounding codebase with respect, and don't overstep (i.e. changing filenames or variables unnecessarily). You should balance being sufficiently ambitious and proactive when completing tasks of this nature.
|
||||
|
||||
You should use judicious initiative to decide on the right level of detail and complexity to deliver based on the user's needs. This means showing good judgment that you're capable of doing the right extras without gold-plating. This might be demonstrated by high-value, creative touches when scope of the task is vague; while being surgical and targeted when scope is tightly specified.
|
||||
|
||||
## Sharing progress updates
|
||||
|
||||
For especially longer tasks that you work on (i.e. requiring many tool calls, or a plan with multiple steps), you should provide progress updates back to the user at reasonable intervals. These updates should be structured as a concise sentence or two (no more than 8-10 words long) recapping progress so far in plain language: this update demonstrates your understanding of what needs to be done, progress so far (i.e. files explores, subtasks complete), and where you're going next.
|
||||
|
||||
Before doing large chunks of work that may incur latency as experienced by the user (i.e. writing a new file), you should send a concise message to the user with an update indicating what you're about to do to ensure they know what you're spending time on. Don't start editing or writing large files before informing the user what you are doing and why.
|
||||
|
||||
The messages you send before tool calls should describe what is immediately about to be done next in very concise language. If there was previous work done, this preamble message should also include a note about the work done so far to bring the user along.
|
||||
|
||||
## Presenting your work and final message
|
||||
|
||||
Your final message should read naturally, like an update from a concise teammate. For casual conversation, brainstorming tasks, or quick questions from the user, respond in a friendly, conversational tone. You should ask questions, suggest ideas, and adapt to the user’s style. If you've finished a large amount of work, when describing what you've done to the user, you should follow the final answer formatting guidelines to communicate substantive changes. You don't need to add structured formatting for one-word answers, greetings, or purely conversational exchanges.
|
||||
|
||||
You can skip heavy formatting for single, simple actions or confirmations. In these cases, respond in plain sentences with any relevant next step or quick option. Reserve multisection structured responses for results that need grouping or explanation.
|
||||
|
||||
The user is working on the same computer as you, and has access to your work. As such there's no need to show the full contents of large files you have already written unless the user explicitly asks for them. Similarly, if you've created or modified files using `edit`, there's no need to tell users to "save the file" or "copy the code into a file"—just reference the file path.
|
||||
|
||||
If there's something that you think you could help with as a logical next step, concisely ask the user if they want you to do so. Good examples of this are running tests, committing changes, or building out the next logical component. If there’s something that you couldn't do (even with approval) but that the user might want to do (such as verifying changes by running the app), include those instructions succinctly.
|
||||
|
||||
Brevity is very important as a default. You should be very concise (i.e. no more than 10 lines), but can relax this requirement for tasks where additional detail and comprehensiveness is important for the user's understanding.
|
||||
|
||||
### Final answer structure and style guidelines
|
||||
|
||||
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
|
||||
|
||||
**Section Headers**
|
||||
|
||||
- Use only when they improve clarity — they are not mandatory for every answer.
|
||||
- Choose descriptive names that fit the content
|
||||
- Keep headers short (1–3 words) and in `**Title Case**`. Always start headers with `**` and end with `**`
|
||||
- Leave no blank line before the first bullet under a header.
|
||||
- Section headers should only be used where they genuinely improve scannability; avoid fragmenting the answer.
|
||||
|
||||
**Bullets**
|
||||
|
||||
- Use `-` followed by a space for every bullet.
|
||||
- Merge related points when possible; avoid a bullet for every trivial detail.
|
||||
- Keep bullets to one line unless breaking for clarity is unavoidable.
|
||||
- Group into short lists (4–6 bullets) ordered by importance.
|
||||
- Use consistent keyword phrasing and formatting across sections.
|
||||
|
||||
**Monospace**
|
||||
|
||||
- Wrap all commands, file paths, env vars, and code identifiers in backticks (`` `...` ``).
|
||||
- Apply to inline examples and to bullet keywords if the keyword itself is a literal file/command.
|
||||
- Never mix monospace and bold markers; choose one based on whether it’s a keyword (`**`) or inline code/path (`` ` ``).
|
||||
|
||||
**File References**
|
||||
When referencing files in your response, make sure to include the relevant start line and always follow the below rules:
|
||||
* Use inline code to make file paths clickable.
|
||||
* Each reference should have a standalone path. Even if it's the same file.
|
||||
* Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix.
|
||||
* Line/column (1‑based, optional): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
|
||||
* Do not use URIs like file://, vscode://, or https://.
|
||||
* Do not provide range of lines
|
||||
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\repo\project\main.rs:12:5
|
||||
|
||||
**Structure**
|
||||
|
||||
- Place related bullets together; don’t mix unrelated concepts in the same section.
|
||||
- Order sections from general → specific → supporting info.
|
||||
- For subsections (e.g., “Binaries” under “Rust Workspace”), introduce with a bolded keyword bullet, then list items under it.
|
||||
- Match structure to complexity:
|
||||
- Multi-part or detailed results → use clear headers and grouped bullets.
|
||||
- Simple results → minimal headers, possibly just a short list or paragraph.
|
||||
|
||||
**Tone**
|
||||
|
||||
- Keep the voice collaborative and natural, like a coding partner handing off work.
|
||||
- Be concise and factual — no filler or conversational commentary and avoid unnecessary repetition
|
||||
- Use present tense and active voice (e.g., “Runs tests” not “This will run tests”).
|
||||
- Keep descriptions self-contained; don’t refer to “above” or “below”.
|
||||
- Use parallel structure in lists for consistency.
|
||||
|
||||
**Don’t**
|
||||
|
||||
- Don’t use literal words “bold” or “monospace” in the content.
|
||||
- Don’t nest bullets or create deep hierarchies.
|
||||
- Don’t output ANSI escape codes directly — the CLI renderer applies them.
|
||||
- Don’t cram unrelated keywords into a single bullet; split for clarity.
|
||||
- Don’t let keyword lists run long — wrap or reformat for scannability.
|
||||
|
||||
Generally, ensure your final answers adapt their shape and depth to the request. For example, answers to code explanations should have a precise, structured explanation with code references that answer the question directly. For tasks with a simple implementation, lead with the outcome and supplement only with what’s needed for clarity. Larger changes can be presented as a logical walkthrough of your approach, grouping related steps, explaining rationale where it adds value, and highlighting next actions to accelerate the user. Your answers should provide the right level of detail while being easily scannable.
|
||||
|
||||
For casual greetings, acknowledgements, or other one-off conversational messages that are not delivering substantive information or structured results, respond naturally without section headers or bullet formatting.
|
||||
|
||||
# Tool Guidelines
|
||||
|
||||
## Shell commands
|
||||
|
||||
When using the shell, you must adhere to the following guidelines:
|
||||
|
||||
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
|
||||
- Read files in chunks with a max chunk size of 250 lines. Do not use python scripts to attempt to output larger chunks of a file. Command line output will be truncated after 10 kilobytes or 256 lines of output, regardless of the command used.
|
||||
|
||||
## `todowrite`
|
||||
|
||||
A tool named `todowrite` is available to you. You can use it to keep an up‑to‑date, step‑by‑step plan for the task.
|
||||
|
||||
To create a new plan, call `todowrite` with a short list of 1‑sentence steps (no more than 5-7 words each) with a `status` for each step (`pending`, `in_progress`, or `completed`).
|
||||
|
||||
When steps have been completed, use `todowrite` to mark each finished step as
|
||||
`completed` and the next step you are working on as `in_progress`. There should
|
||||
always be exactly one `in_progress` step until everything is done. You can mark
|
||||
multiple items as complete in a single `todowrite` call.
|
||||
|
||||
If all steps are complete, ensure you call `todowrite` to mark all steps as `completed`.
|
||||
303
internal/registry/kiro_model_converter.go
Normal file
303
internal/registry/kiro_model_converter.go
Normal file
@@ -0,0 +1,303 @@
|
||||
// Package registry provides Kiro model conversion utilities.
|
||||
// This file handles converting dynamic Kiro API model lists to the internal ModelInfo format,
|
||||
// and merging with static metadata for thinking support and other capabilities.
|
||||
package registry
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// KiroAPIModel represents a model from Kiro API response.
|
||||
// This is a local copy to avoid import cycles with the kiro package.
|
||||
// The structure mirrors kiro.KiroModel for easy data conversion.
|
||||
type KiroAPIModel struct {
|
||||
// ModelID is the unique identifier for the model (e.g., "claude-sonnet-4.5")
|
||||
ModelID string
|
||||
// ModelName is the human-readable name
|
||||
ModelName string
|
||||
// Description is the model description
|
||||
Description string
|
||||
// RateMultiplier is the credit multiplier for this model
|
||||
RateMultiplier float64
|
||||
// RateUnit is the unit for rate calculation (e.g., "credit")
|
||||
RateUnit string
|
||||
// MaxInputTokens is the maximum input token limit
|
||||
MaxInputTokens int
|
||||
}
|
||||
|
||||
// DefaultKiroThinkingSupport defines the default thinking configuration for Kiro models.
|
||||
// All Kiro models support thinking with the following budget range.
|
||||
var DefaultKiroThinkingSupport = &ThinkingSupport{
|
||||
Min: 1024, // Minimum thinking budget tokens
|
||||
Max: 32000, // Maximum thinking budget tokens
|
||||
ZeroAllowed: true, // Allow disabling thinking with 0
|
||||
DynamicAllowed: true, // Allow dynamic thinking budget (-1)
|
||||
}
|
||||
|
||||
// DefaultKiroContextLength is the default context window size for Kiro models.
|
||||
const DefaultKiroContextLength = 200000
|
||||
|
||||
// DefaultKiroMaxCompletionTokens is the default max completion tokens for Kiro models.
|
||||
const DefaultKiroMaxCompletionTokens = 64000
|
||||
|
||||
// ConvertKiroAPIModels converts Kiro API models to internal ModelInfo format.
|
||||
// It performs the following transformations:
|
||||
// - Normalizes model ID (e.g., claude-sonnet-4.5 → kiro-claude-sonnet-4-5)
|
||||
// - Adds default thinking support metadata
|
||||
// - Sets default context length and max completion tokens if not provided
|
||||
//
|
||||
// Parameters:
|
||||
// - kiroModels: List of models from Kiro API response
|
||||
//
|
||||
// Returns:
|
||||
// - []*ModelInfo: Converted model information list
|
||||
func ConvertKiroAPIModels(kiroModels []*KiroAPIModel) []*ModelInfo {
|
||||
if len(kiroModels) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
result := make([]*ModelInfo, 0, len(kiroModels))
|
||||
|
||||
for _, km := range kiroModels {
|
||||
// Skip nil models
|
||||
if km == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip models without valid ID
|
||||
if km.ModelID == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Normalize the model ID to kiro-* format
|
||||
normalizedID := normalizeKiroModelID(km.ModelID)
|
||||
|
||||
// Create ModelInfo with converted data
|
||||
info := &ModelInfo{
|
||||
ID: normalizedID,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: generateKiroDisplayName(km.ModelName, normalizedID),
|
||||
Description: km.Description,
|
||||
// Use MaxInputTokens from API if available, otherwise use default
|
||||
ContextLength: getContextLength(km.MaxInputTokens),
|
||||
MaxCompletionTokens: DefaultKiroMaxCompletionTokens,
|
||||
// All Kiro models support thinking
|
||||
Thinking: cloneThinkingSupport(DefaultKiroThinkingSupport),
|
||||
}
|
||||
|
||||
result = append(result, info)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// GenerateAgenticVariants creates -agentic variants for each model.
|
||||
// Agentic variants are optimized for coding agents with chunked writes.
|
||||
//
|
||||
// Parameters:
|
||||
// - models: Base models to generate variants for
|
||||
//
|
||||
// Returns:
|
||||
// - []*ModelInfo: Combined list of base models and their agentic variants
|
||||
func GenerateAgenticVariants(models []*ModelInfo) []*ModelInfo {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pre-allocate result with capacity for both base models and variants
|
||||
result := make([]*ModelInfo, 0, len(models)*2)
|
||||
|
||||
for _, model := range models {
|
||||
if model == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Add the base model first
|
||||
result = append(result, model)
|
||||
|
||||
// Skip if model already has -agentic suffix
|
||||
if strings.HasSuffix(model.ID, "-agentic") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip special models that shouldn't have agentic variants
|
||||
if model.ID == "kiro-auto" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create agentic variant
|
||||
agenticModel := &ModelInfo{
|
||||
ID: model.ID + "-agentic",
|
||||
Object: model.Object,
|
||||
Created: model.Created,
|
||||
OwnedBy: model.OwnedBy,
|
||||
Type: model.Type,
|
||||
DisplayName: model.DisplayName + " (Agentic)",
|
||||
Description: generateAgenticDescription(model.Description),
|
||||
ContextLength: model.ContextLength,
|
||||
MaxCompletionTokens: model.MaxCompletionTokens,
|
||||
Thinking: cloneThinkingSupport(model.Thinking),
|
||||
}
|
||||
|
||||
result = append(result, agenticModel)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// MergeWithStaticMetadata merges dynamic models with static metadata.
|
||||
// Static metadata takes priority for any overlapping fields.
|
||||
// This allows manual overrides for specific models while keeping dynamic discovery.
|
||||
//
|
||||
// Parameters:
|
||||
// - dynamicModels: Models from Kiro API (converted to ModelInfo)
|
||||
// - staticModels: Predefined model metadata (from GetKiroModels())
|
||||
//
|
||||
// Returns:
|
||||
// - []*ModelInfo: Merged model list with static metadata taking priority
|
||||
func MergeWithStaticMetadata(dynamicModels, staticModels []*ModelInfo) []*ModelInfo {
|
||||
if len(dynamicModels) == 0 && len(staticModels) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build a map of static models for quick lookup
|
||||
staticMap := make(map[string]*ModelInfo, len(staticModels))
|
||||
for _, sm := range staticModels {
|
||||
if sm != nil && sm.ID != "" {
|
||||
staticMap[sm.ID] = sm
|
||||
}
|
||||
}
|
||||
|
||||
// Build result, preferring static metadata where available
|
||||
seenIDs := make(map[string]struct{})
|
||||
result := make([]*ModelInfo, 0, len(dynamicModels)+len(staticModels))
|
||||
|
||||
// First, process dynamic models and merge with static if available
|
||||
for _, dm := range dynamicModels {
|
||||
if dm == nil || dm.ID == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip duplicates
|
||||
if _, seen := seenIDs[dm.ID]; seen {
|
||||
continue
|
||||
}
|
||||
seenIDs[dm.ID] = struct{}{}
|
||||
|
||||
// Check if static metadata exists for this model
|
||||
if sm, exists := staticMap[dm.ID]; exists {
|
||||
// Static metadata takes priority - use static model
|
||||
result = append(result, sm)
|
||||
} else {
|
||||
// No static metadata - use dynamic model
|
||||
result = append(result, dm)
|
||||
}
|
||||
}
|
||||
|
||||
// Add any static models not in dynamic list
|
||||
for _, sm := range staticModels {
|
||||
if sm == nil || sm.ID == "" {
|
||||
continue
|
||||
}
|
||||
if _, seen := seenIDs[sm.ID]; seen {
|
||||
continue
|
||||
}
|
||||
seenIDs[sm.ID] = struct{}{}
|
||||
result = append(result, sm)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// normalizeKiroModelID converts Kiro API model IDs to internal format.
|
||||
// Transformation rules:
|
||||
// - Adds "kiro-" prefix if not present
|
||||
// - Replaces dots with hyphens (e.g., 4.5 → 4-5)
|
||||
// - Handles special cases like "auto" → "kiro-auto"
|
||||
//
|
||||
// Examples:
|
||||
// - "claude-sonnet-4.5" → "kiro-claude-sonnet-4-5"
|
||||
// - "claude-opus-4.5" → "kiro-claude-opus-4-5"
|
||||
// - "auto" → "kiro-auto"
|
||||
// - "kiro-claude-sonnet-4-5" → "kiro-claude-sonnet-4-5" (unchanged)
|
||||
func normalizeKiroModelID(modelID string) string {
|
||||
if modelID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Trim whitespace
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
|
||||
// Replace dots with hyphens (e.g., 4.5 → 4-5)
|
||||
normalized := strings.ReplaceAll(modelID, ".", "-")
|
||||
|
||||
// Add kiro- prefix if not present
|
||||
if !strings.HasPrefix(normalized, "kiro-") {
|
||||
normalized = "kiro-" + normalized
|
||||
}
|
||||
|
||||
return normalized
|
||||
}
|
||||
|
||||
// generateKiroDisplayName creates a human-readable display name.
|
||||
// Uses the API-provided model name if available, otherwise generates from ID.
|
||||
func generateKiroDisplayName(modelName, normalizedID string) string {
|
||||
if modelName != "" {
|
||||
return "Kiro " + modelName
|
||||
}
|
||||
|
||||
// Generate from normalized ID by removing kiro- prefix and formatting
|
||||
displayID := strings.TrimPrefix(normalizedID, "kiro-")
|
||||
// Capitalize first letter of each word
|
||||
words := strings.Split(displayID, "-")
|
||||
for i, word := range words {
|
||||
if len(word) > 0 {
|
||||
words[i] = strings.ToUpper(word[:1]) + word[1:]
|
||||
}
|
||||
}
|
||||
return "Kiro " + strings.Join(words, " ")
|
||||
}
|
||||
|
||||
// generateAgenticDescription creates description for agentic variants.
|
||||
func generateAgenticDescription(baseDescription string) string {
|
||||
if baseDescription == "" {
|
||||
return "Optimized for coding agents with chunked writes"
|
||||
}
|
||||
return baseDescription + " (Agentic mode: chunked writes)"
|
||||
}
|
||||
|
||||
// getContextLength returns the context length, using default if not provided.
|
||||
func getContextLength(maxInputTokens int) int {
|
||||
if maxInputTokens > 0 {
|
||||
return maxInputTokens
|
||||
}
|
||||
return DefaultKiroContextLength
|
||||
}
|
||||
|
||||
// cloneThinkingSupport creates a deep copy of ThinkingSupport.
|
||||
// Returns nil if input is nil.
|
||||
func cloneThinkingSupport(ts *ThinkingSupport) *ThinkingSupport {
|
||||
if ts == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
clone := &ThinkingSupport{
|
||||
Min: ts.Min,
|
||||
Max: ts.Max,
|
||||
ZeroAllowed: ts.ZeroAllowed,
|
||||
DynamicAllowed: ts.DynamicAllowed,
|
||||
}
|
||||
|
||||
// Deep copy Levels slice if present
|
||||
if len(ts.Levels) > 0 {
|
||||
clone.Levels = make([]string, len(ts.Levels))
|
||||
copy(clone.Levels, ts.Levels)
|
||||
}
|
||||
|
||||
return clone
|
||||
}
|
||||
@@ -27,7 +27,7 @@ func GetClaudeModels() []*ModelInfo {
|
||||
DisplayName: "Claude 4.5 Sonnet",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-5-20251101",
|
||||
@@ -39,7 +39,7 @@ func GetClaudeModels() []*ModelInfo {
|
||||
Description: "Premium model combining maximum intelligence with practical performance",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-1-20250805",
|
||||
@@ -50,7 +50,7 @@ func GetClaudeModels() []*ModelInfo {
|
||||
DisplayName: "Claude 4.1 Opus",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: false, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-20250514",
|
||||
@@ -61,7 +61,7 @@ func GetClaudeModels() []*ModelInfo {
|
||||
DisplayName: "Claude 4 Opus",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: false, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-20250514",
|
||||
@@ -72,7 +72,7 @@ func GetClaudeModels() []*ModelInfo {
|
||||
DisplayName: "Claude 4 Sonnet",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: false, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-3-7-sonnet-20250219",
|
||||
@@ -83,7 +83,7 @@ func GetClaudeModels() []*ModelInfo {
|
||||
DisplayName: "Claude 3.7 Sonnet",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 8192,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: false, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-3-5-haiku-20241022",
|
||||
@@ -160,7 +160,22 @@ func GetGeminiModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
Created: 1765929600,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3-flash-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Flash Preview",
|
||||
Description: "Gemini 3 Flash Preview",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-pro-image-preview",
|
||||
@@ -175,7 +190,7 @@ func GetGeminiModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -240,7 +255,22 @@ func GetGeminiVertexModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
Created: 1765929600,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3-flash-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Flash Preview",
|
||||
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-pro-image-preview",
|
||||
@@ -255,7 +285,68 @@ func GetGeminiVertexModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
// Imagen image generation models - use :predict action
|
||||
{
|
||||
ID: "imagen-4.0-generate-001",
|
||||
Object: "model",
|
||||
Created: 1750000000,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/imagen-4.0-generate-001",
|
||||
Version: "4.0",
|
||||
DisplayName: "Imagen 4.0 Generate",
|
||||
Description: "Imagen 4.0 image generation model",
|
||||
SupportedGenerationMethods: []string{"predict"},
|
||||
},
|
||||
{
|
||||
ID: "imagen-4.0-ultra-generate-001",
|
||||
Object: "model",
|
||||
Created: 1750000000,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/imagen-4.0-ultra-generate-001",
|
||||
Version: "4.0",
|
||||
DisplayName: "Imagen 4.0 Ultra Generate",
|
||||
Description: "Imagen 4.0 Ultra high-quality image generation model",
|
||||
SupportedGenerationMethods: []string{"predict"},
|
||||
},
|
||||
{
|
||||
ID: "imagen-3.0-generate-002",
|
||||
Object: "model",
|
||||
Created: 1740000000,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/imagen-3.0-generate-002",
|
||||
Version: "3.0",
|
||||
DisplayName: "Imagen 3.0 Generate",
|
||||
Description: "Imagen 3.0 image generation model",
|
||||
SupportedGenerationMethods: []string{"predict"},
|
||||
},
|
||||
{
|
||||
ID: "imagen-3.0-fast-generate-001",
|
||||
Object: "model",
|
||||
Created: 1740000000,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/imagen-3.0-fast-generate-001",
|
||||
Version: "3.0",
|
||||
DisplayName: "Imagen 3.0 Fast Generate",
|
||||
Description: "Imagen 3.0 fast image generation model",
|
||||
SupportedGenerationMethods: []string{"predict"},
|
||||
},
|
||||
{
|
||||
ID: "imagen-4.0-fast-generate-001",
|
||||
Object: "model",
|
||||
Created: 1750000000,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/imagen-4.0-fast-generate-001",
|
||||
Version: "4.0",
|
||||
DisplayName: "Imagen 4.0 Fast Generate",
|
||||
Description: "Imagen 4.0 fast image generation model",
|
||||
SupportedGenerationMethods: []string{"predict"},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -317,11 +408,26 @@ func GetGeminiCLIModels() []*ModelInfo {
|
||||
Name: "models/gemini-3-pro-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Pro Preview",
|
||||
Description: "Gemini 3 Pro Preview",
|
||||
Description: "Our most intelligent model with SOTA reasoning and multimodal understanding, and powerful agentic and vibe coding capabilities",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
Created: 1765929600,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3-flash-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Flash Preview",
|
||||
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -389,6 +495,21 @@ func GetAIStudioModels() []*ModelInfo {
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
Created: 1765929600,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3-flash-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Flash Preview",
|
||||
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "gemini-pro-latest",
|
||||
Object: "model",
|
||||
@@ -582,6 +703,20 @@ func GetOpenAIModels() []*ModelInfo {
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.2-codex",
|
||||
Object: "model",
|
||||
Created: 1765440000,
|
||||
OwnedBy: "openai",
|
||||
Type: "openai",
|
||||
Version: "gpt-5.2",
|
||||
DisplayName: "GPT 5.2 Codex",
|
||||
Description: "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.",
|
||||
ContextLength: 400000,
|
||||
MaxCompletionTokens: 128000,
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -630,6 +765,13 @@ func GetQwenModels() []*ModelInfo {
|
||||
}
|
||||
}
|
||||
|
||||
// iFlowThinkingSupport is a shared ThinkingSupport configuration for iFlow models
|
||||
// that support thinking mode via chat_template_kwargs.enable_thinking (boolean toggle).
|
||||
// Uses level-based configuration so standard normalization flows apply before conversion.
|
||||
var iFlowThinkingSupport = &ThinkingSupport{
|
||||
Levels: []string{"none", "auto", "minimal", "low", "medium", "high", "xhigh"},
|
||||
}
|
||||
|
||||
// GetIFlowModels returns supported models for iFlow OAuth accounts.
|
||||
func GetIFlowModels() []*ModelInfo {
|
||||
entries := []struct {
|
||||
@@ -645,20 +787,23 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
|
||||
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400},
|
||||
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400},
|
||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
||||
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2-reasoner", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Reasoner", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2-exp", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000},
|
||||
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000},
|
||||
{ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200},
|
||||
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
|
||||
{ID: "deepseek-v3", DisplayName: "DeepSeek-V3-671B", Description: "DeepSeek V3 671B", Created: 1734307200},
|
||||
{ID: "qwen3-32b", DisplayName: "Qwen3-32B", Description: "Qwen3 32B", Created: 1747094400},
|
||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
|
||||
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
||||
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
|
||||
}
|
||||
models := make([]*ModelInfo, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
@@ -681,23 +826,63 @@ func GetIFlowModels() []*ModelInfo {
|
||||
type AntigravityModelConfig struct {
|
||||
Thinking *ThinkingSupport
|
||||
MaxCompletionTokens int
|
||||
Name string
|
||||
}
|
||||
|
||||
// GetAntigravityModelConfig returns static configuration for antigravity models.
|
||||
// Keys use the ALIASED model names (after modelName2Alias conversion) for direct lookup.
|
||||
// Keys use upstream model names returned by the Antigravity models endpoint.
|
||||
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
return map[string]*AntigravityModelConfig{
|
||||
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash"},
|
||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash-lite"},
|
||||
"gemini-2.5-computer-use-preview-10-2025": {Name: "models/gemini-2.5-computer-use-preview-10-2025"},
|
||||
"gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-3-pro-preview"},
|
||||
"gemini-3-pro-image-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-3-pro-image-preview"},
|
||||
"gemini-claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||
"rev19-uic3-1p": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}},
|
||||
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
||||
"gpt-oss-120b-medium": {},
|
||||
"tab_flash_lite_preview": {},
|
||||
}
|
||||
}
|
||||
|
||||
// LookupStaticModelInfo searches all static model definitions for a model by ID.
|
||||
// Returns nil if no matching model is found.
|
||||
func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
if modelID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
allModels := [][]*ModelInfo{
|
||||
GetClaudeModels(),
|
||||
GetGeminiModels(),
|
||||
GetGeminiVertexModels(),
|
||||
GetGeminiCLIModels(),
|
||||
GetAIStudioModels(),
|
||||
GetOpenAIModels(),
|
||||
GetQwenModels(),
|
||||
GetIFlowModels(),
|
||||
}
|
||||
for _, models := range allModels {
|
||||
for _, m := range models {
|
||||
if m != nil && m.ID == modelID {
|
||||
return m
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check Antigravity static config
|
||||
if cfg := GetAntigravityModelConfig()[modelID]; cfg != nil {
|
||||
return &ModelInfo{
|
||||
ID: modelID,
|
||||
Thinking: cfg.Thinking,
|
||||
MaxCompletionTokens: cfg.MaxCompletionTokens,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetGitHubCopilotModels returns the available models for GitHub Copilot.
|
||||
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
|
||||
func GetGitHubCopilotModels() []*ModelInfo {
|
||||
@@ -724,6 +909,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
Description: "OpenAI GPT-5 via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5-mini",
|
||||
@@ -735,6 +921,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
Description: "OpenAI GPT-5 Mini via GitHub Copilot",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5-codex",
|
||||
@@ -746,6 +933,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
Description: "OpenAI GPT-5 Codex via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.1",
|
||||
@@ -757,6 +945,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
Description: "OpenAI GPT-5.1 via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.1-codex",
|
||||
@@ -768,6 +957,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
Description: "OpenAI GPT-5.1 Codex via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.1-codex-mini",
|
||||
@@ -779,6 +969,43 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
Description: "OpenAI GPT-5.1 Codex Mini via GitHub Copilot",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.1-codex-max",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "GPT-5.1 Codex Max",
|
||||
Description: "OpenAI GPT-5.1 Codex Max via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.2",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "GPT-5.2",
|
||||
Description: "OpenAI GPT-5.2 via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.2-codex",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "GPT-5.2 Codex",
|
||||
Description: "OpenAI GPT-5.2 Codex via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
},
|
||||
{
|
||||
ID: "claude-haiku-4.5",
|
||||
@@ -790,6 +1017,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
Description: "Anthropic Claude Haiku 4.5 via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4.1",
|
||||
@@ -801,6 +1029,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
Description: "Anthropic Claude Opus 4.1 via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32000,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4.5",
|
||||
@@ -812,6 +1041,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
Description: "Anthropic Claude Opus 4.5 via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4",
|
||||
@@ -823,6 +1053,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
Description: "Anthropic Claude Sonnet 4 via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4.5",
|
||||
@@ -834,6 +1065,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
Description: "Anthropic Claude Sonnet 4.5 via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "gemini-2.5-pro",
|
||||
@@ -847,13 +1079,24 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 65536,
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-pro",
|
||||
ID: "gemini-3-pro-preview",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "Gemini 3 Pro",
|
||||
Description: "Google Gemini 3 Pro via GitHub Copilot",
|
||||
DisplayName: "Gemini 3 Pro (Preview)",
|
||||
Description: "Google Gemini 3 Pro Preview via GitHub Copilot",
|
||||
ContextLength: 1048576,
|
||||
MaxCompletionTokens: 65536,
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "Gemini 3 Flash (Preview)",
|
||||
Description: "Google Gemini 3 Flash Preview via GitHub Copilot",
|
||||
ContextLength: 1048576,
|
||||
MaxCompletionTokens: 65536,
|
||||
},
|
||||
@@ -869,15 +1112,16 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 16384,
|
||||
},
|
||||
{
|
||||
ID: "raptor-mini",
|
||||
ID: "oswe-vscode-prime",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "Raptor Mini",
|
||||
Description: "Raptor Mini via GitHub Copilot",
|
||||
DisplayName: "Raptor mini (Preview)",
|
||||
Description: "Raptor mini via GitHub Copilot",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -886,6 +1130,18 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
func GetKiroModels() []*ModelInfo {
|
||||
return []*ModelInfo{
|
||||
// --- Base Models ---
|
||||
{
|
||||
ID: "kiro-auto",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro Auto",
|
||||
Description: "Automatic model selection by Kiro",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-opus-4-5",
|
||||
Object: "model",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user