mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-09 15:25:17 +00:00
Compare commits
1650 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
182b31963a | ||
|
|
4f48e5254a | ||
|
|
15dd5db1d7 | ||
|
|
424711b718 | ||
|
|
2b134fc378 | ||
|
|
b9153719b0 | ||
|
|
631e5c8331 | ||
|
|
e9c60a0a67 | ||
|
|
98a1bb5a7f | ||
|
|
ca90487a8c | ||
|
|
1042489f85 | ||
|
|
38277c1ea6 | ||
|
|
ee0c24628f | ||
|
|
3a18f6fcca | ||
|
|
099e734a02 | ||
|
|
a52da26b5d | ||
|
|
522a68a4ea | ||
|
|
a02eda54d0 | ||
|
|
97ef633c57 | ||
|
|
dae8463ba1 | ||
|
|
7c1299922e | ||
|
|
ddcf1f279d | ||
|
|
7e6bb8fdc5 | ||
|
|
9cee8ef87b | ||
|
|
93fb841bcb | ||
|
|
0c05131aeb | ||
|
|
5ebc58fab4 | ||
|
|
2b609dd891 | ||
|
|
a8cbc68c3e | ||
|
|
11a795a01c | ||
|
|
89c428216e | ||
|
|
2695a99623 | ||
|
|
242aecd924 | ||
|
|
ce8cc1ba33 | ||
|
|
ad5253bd2b | ||
|
|
97fdd2e088 | ||
|
|
9397f7049f | ||
|
|
a14d19b92c | ||
|
|
8ae0c05ea6 | ||
|
|
8822f20d17 | ||
|
|
553d6f50ea | ||
|
|
f0e5a5a367 | ||
|
|
f6dfea9357 | ||
|
|
cc8dc7f62c | ||
|
|
a3846ea513 | ||
|
|
8d44be858e | ||
|
|
0e6bb076e9 | ||
|
|
ac135fc7cb | ||
|
|
4e1d09809d | ||
|
|
9e855f8100 | ||
|
|
25680a8259 | ||
|
|
13c93e8cfd | ||
|
|
88aa1b9fd1 | ||
|
|
352cb98ff0 | ||
|
|
ac95e92829 | ||
|
|
8526c2da25 | ||
|
|
68a6cabf8b | ||
|
|
ac0e387da1 | ||
|
|
7fe1d102cb | ||
|
|
5850492a93 | ||
|
|
fdbd4041ca | ||
|
|
ebef1fae2a | ||
|
|
c51851689b | ||
|
|
419bf784ab | ||
|
|
4bbeb92e9a | ||
|
|
b436dad8bc | ||
|
|
6ae15d6c44 | ||
|
|
0468bde0d6 | ||
|
|
1d7329e797 | ||
|
|
48ffc4dee7 | ||
|
|
7ebd8f0c44 | ||
|
|
b680c146c1 | ||
|
|
7d6660d181 | ||
|
|
d8e3d4e2b6 | ||
|
|
d26ad8224d | ||
|
|
5c84d69d42 | ||
|
|
527e4b7f26 | ||
|
|
b48485b42b | ||
|
|
79009bb3d4 | ||
|
|
26fc611f86 | ||
|
|
b43743d4f1 | ||
|
|
179e5434b1 | ||
|
|
9f95b31158 | ||
|
|
5da07eae4c | ||
|
|
835ae178d4 | ||
|
|
c80ab8bf0d | ||
|
|
ce87714ef1 | ||
|
|
0452b869e8 | ||
|
|
d2e5857b82 | ||
|
|
f9b005f21f | ||
|
|
532107b4fa | ||
|
|
c44793789b | ||
|
|
4e99525279 | ||
|
|
7547d1d0b3 | ||
|
|
68934942d0 | ||
|
|
09fec34e1c | ||
|
|
9229708b6c | ||
|
|
914db94e79 | ||
|
|
660bd7eff5 | ||
|
|
b907d21851 | ||
|
|
dd44413ba5 | ||
|
|
10fa0f2062 | ||
|
|
d6cc976d1f | ||
|
|
8aa2cce8c5 | ||
|
|
bf9b2c49df | ||
|
|
77b42c6165 | ||
|
|
446150a747 | ||
|
|
1cbc4834e1 | ||
|
|
30338ecec4 | ||
|
|
9a37defed3 | ||
|
|
c83a057996 | ||
|
|
a8a5d03c33 | ||
|
|
76aa917882 | ||
|
|
6ac9b31e4e | ||
|
|
0ad3e8457f | ||
|
|
444a47ae63 | ||
|
|
725f4fdff4 | ||
|
|
c23e46f45d | ||
|
|
b148820c35 | ||
|
|
134f41496d | ||
|
|
c5838dd58d | ||
|
|
b6ca5ef7ce | ||
|
|
1ae994b4aa | ||
|
|
84e9793e61 | ||
|
|
32e64dacfd | ||
|
|
cc1d8f6629 | ||
|
|
5446cd2b02 | ||
|
|
8de0885b7d | ||
|
|
16243f18fd | ||
|
|
a6ce5f36e6 | ||
|
|
e73cf42e28 | ||
|
|
b45343e812 | ||
|
|
8599b1560e | ||
|
|
8bde8c37c0 | ||
|
|
82df5bf88a | ||
|
|
acb1066de8 | ||
|
|
27c68f5bb2 | ||
|
|
68dd2bfe82 | ||
|
|
65a87815e7 | ||
|
|
b80793ca82 | ||
|
|
601550f238 | ||
|
|
41b1cf2273 | ||
|
|
2baf35b3ef | ||
|
|
846e75b893 | ||
|
|
fc0257d6d9 | ||
|
|
f3c164d345 | ||
|
|
4040b1e766 | ||
|
|
3b4f9f43db | ||
|
|
37a09ecb23 | ||
|
|
0da34d3c2d | ||
|
|
74bf7eda8f | ||
|
|
9032042cfa | ||
|
|
030bf5e6c7 | ||
|
|
d3100085b0 | ||
|
|
f481d25133 | ||
|
|
8c6c90da74 | ||
|
|
24bcfd9c03 | ||
|
|
816fb4c5da | ||
|
|
c1bb77c7c9 | ||
|
|
6bcac3a55a | ||
|
|
fc346f4537 | ||
|
|
43e531a3b6 | ||
|
|
d24ea4ce2a | ||
|
|
2c30c981ae | ||
|
|
aa1da8a858 | ||
|
|
f1e9a787d7 | ||
|
|
4eeec297de | ||
|
|
77cc4ce3a0 | ||
|
|
37dfea1d3f | ||
|
|
e6626c672a | ||
|
|
c66cb0afd2 | ||
|
|
fb48eee973 | ||
|
|
bb44e5ec44 | ||
|
|
c785c1a3ca | ||
|
|
514ae341c8 | ||
|
|
0659ffab75 | ||
|
|
8ce07f38dd | ||
|
|
7cb398d167 | ||
|
|
c3e12c5e58 | ||
|
|
1825fc7503 | ||
|
|
48732ba05e | ||
|
|
acf483c9e6 | ||
|
|
3b3e0d1141 | ||
|
|
7acd428507 | ||
|
|
0aaf177640 | ||
|
|
450d1227bd | ||
|
|
b7588428c5 | ||
|
|
492b9c46f0 | ||
|
|
6e634fe3f9 | ||
|
|
4e26182d14 | ||
|
|
8f97a5f77c | ||
|
|
eb7571936c | ||
|
|
5382764d8a | ||
|
|
49c8ec69d0 | ||
|
|
3b421c8181 | ||
|
|
21d2329947 | ||
|
|
0993413bab | ||
|
|
713388dd7b | ||
|
|
2a4d3e60f3 | ||
|
|
8b5af2ab84 | ||
|
|
e6c7af0fa9 | ||
|
|
837aa6e3aa | ||
|
|
d210be06c2 | ||
|
|
d887716ebd | ||
|
|
5dc1848466 | ||
|
|
9491517b26 | ||
|
|
9370b5bd04 | ||
|
|
abb51a0d93 | ||
|
|
c8d809131b | ||
|
|
dd71c73a9f | ||
|
|
afc8a0f9be | ||
|
|
af8e9ef458 | ||
|
|
cec6f993ad | ||
|
|
950de29f48 | ||
|
|
d6ec33e8e1 | ||
|
|
081cfe806e | ||
|
|
c1c62a6c04 | ||
|
|
a99522224f | ||
|
|
f5d46b9ca2 | ||
|
|
d693d7993b | ||
|
|
5936f9895c | ||
|
|
2fdf5d2793 | ||
|
|
b3da00d2ed | ||
|
|
740277a9f2 | ||
|
|
f91807b6b9 | ||
|
|
57d18bb226 | ||
|
|
10b9c6cb8a | ||
|
|
b24786f8a7 | ||
|
|
7b0eb41ebc | ||
|
|
70949929db | ||
|
|
7c9c89dace | ||
|
|
ef5901c81b | ||
|
|
d4829c82f7 | ||
|
|
a5f4166a9b | ||
|
|
0cbfe7f457 | ||
|
|
f2b1ec4f9e | ||
|
|
1cc21cc45b | ||
|
|
07cf616e2b | ||
|
|
2b8c466e88 | ||
|
|
ca2174ea48 | ||
|
|
c09fb2a79d | ||
|
|
4445a165e9 | ||
|
|
e92e2af71a | ||
|
|
a6bdd9a652 | ||
|
|
349a6349b3 | ||
|
|
00822770ec | ||
|
|
1a0ceda0fc | ||
|
|
b9ae4ab803 | ||
|
|
72add453d2 | ||
|
|
2789396435 | ||
|
|
61da7bd981 | ||
|
|
ae4c502792 | ||
|
|
ec6068060b | ||
|
|
ecb01d3dcd | ||
|
|
22c0c00bd4 | ||
|
|
9eb3e7a6c4 | ||
|
|
357c191510 | ||
|
|
5db244af76 | ||
|
|
dc375d1b74 | ||
|
|
9c040445af | ||
|
|
fff866424e | ||
|
|
2d12becfd6 | ||
|
|
252f7e0751 | ||
|
|
b2b17528cb | ||
|
|
55f938164b | ||
|
|
76294f0c59 | ||
|
|
2bcee78c6e | ||
|
|
7fe8246a9f | ||
|
|
93fe58e31e | ||
|
|
e5b5dc870f | ||
|
|
a54877c023 | ||
|
|
bb86a0c0c4 | ||
|
|
5fa23c7f41 | ||
|
|
f9a09b7f23 | ||
|
|
b0cde626fe | ||
|
|
e42ef9a95d | ||
|
|
abf1629ec7 | ||
|
|
73dc0b10b8 | ||
|
|
2ea95266e3 | ||
|
|
922d4141c0 | ||
|
|
1f8f198c45 | ||
|
|
c55275342c | ||
|
|
9261b0c20b | ||
|
|
7cc725496e | ||
|
|
5726a99c80 | ||
|
|
b5756bf729 | ||
|
|
709d999f9f | ||
|
|
24c18614f0 | ||
|
|
603f06a762 | ||
|
|
98f0a3e3bd | ||
|
|
e186ccb0d4 | ||
|
|
8fc0b08b70 | ||
|
|
52a257dc24 | ||
|
|
a12d907f55 | ||
|
|
453aaf8774 | ||
|
|
1b1ab1fb9b | ||
|
|
a9d0bb72da | ||
|
|
d328e54e4b | ||
|
|
5a7932cba4 | ||
|
|
1dbeb0827a | ||
|
|
2c8821891c | ||
|
|
0a2555b0f3 | ||
|
|
020df41efe | ||
|
|
f8f8cf17ce | ||
|
|
f31f7f701a | ||
|
|
b5fe78eb70 | ||
|
|
d1f667cf8d | ||
|
|
54ad7c1b6b | ||
|
|
d560c20c26 | ||
|
|
5abeca1f9e | ||
|
|
294eac3a88 | ||
|
|
a31104020c | ||
|
|
65bec4d734 | ||
|
|
edb2993838 | ||
|
|
c0d8e0dec7 | ||
|
|
795da13d5d | ||
|
|
55789df275 | ||
|
|
9e652a3540 | ||
|
|
46a6782065 | ||
|
|
c359f61859 | ||
|
|
908c8eab5b | ||
|
|
f5f2c69233 | ||
|
|
63d4de5eea | ||
|
|
af15083496 | ||
|
|
c4722e42b1 | ||
|
|
f9a991365f | ||
|
|
6df16bedba | ||
|
|
632a2fd2f2 | ||
|
|
5626637fbd | ||
|
|
2db89211a9 | ||
|
|
587371eb14 | ||
|
|
75818b1e25 | ||
|
|
a45c6defa7 | ||
|
|
cbe56955a9 | ||
|
|
8ea6ac913d | ||
|
|
ae1e8a5191 | ||
|
|
b3ccc55f09 | ||
|
|
40bee3e8d9 | ||
|
|
1ce56d7413 | ||
|
|
41a78be3a2 | ||
|
|
1ff5de9a31 | ||
|
|
46a6853046 | ||
|
|
4b2d40bd67 | ||
|
|
726f1a590c | ||
|
|
575881cb59 | ||
|
|
d02df0141b | ||
|
|
e4bc9da913 | ||
|
|
8c6be49625 | ||
|
|
c727e4251f | ||
|
|
99266be998 | ||
|
|
d0f3fd96f8 | ||
|
|
f361b2716d | ||
|
|
086d8d0d0b | ||
|
|
627dee1dac | ||
|
|
93147dddeb | ||
|
|
c0f9b15a58 | ||
|
|
6f2fbdcbae | ||
|
|
55c3197fb8 | ||
|
|
65debb874f | ||
|
|
3caadac003 | ||
|
|
6a9e3a6b84 | ||
|
|
269972440a | ||
|
|
cce13e6ad2 | ||
|
|
8a565dcad8 | ||
|
|
d536110404 | ||
|
|
48e957ddff | ||
|
|
94563d622c | ||
|
|
5a2cf0d53c | ||
|
|
2573358173 | ||
|
|
09cd3cff91 | ||
|
|
ab0bf1b517 | ||
|
|
58e09f8e5f | ||
|
|
2334a2b174 | ||
|
|
bc61bf36b2 | ||
|
|
7726a44ca2 | ||
|
|
dc55fb0ce3 | ||
|
|
a146c6c0aa | ||
|
|
4c133d3ea9 | ||
|
|
544238772a | ||
|
|
f3ccd85ba1 | ||
|
|
dc279de443 | ||
|
|
bf1634bda0 | ||
|
|
166d2d24d9 | ||
|
|
4cbcc835d1 | ||
|
|
b93026d83a | ||
|
|
5ed2133ff9 | ||
|
|
e9dd44e623 | ||
|
|
cc8c4ffb5f | ||
|
|
1510bfcb6f | ||
|
|
bcd2208b51 | ||
|
|
09b19f5c4e | ||
|
|
7b01ca0e2e | ||
|
|
9c65e17a21 | ||
|
|
fe6fc628ed | ||
|
|
8192eeabc8 | ||
|
|
c3f1cdd7e5 | ||
|
|
c6bd91b86b | ||
|
|
ce0c6aa82b | ||
|
|
349ddcaa89 | ||
|
|
bb9fe52f1e | ||
|
|
afe4c1bfb7 | ||
|
|
3c85d2a4d7 | ||
|
|
865af9f19e | ||
|
|
2b97cb98b5 | ||
|
|
938a799263 | ||
|
|
e17d4f8d98 | ||
|
|
c8cae1f74d | ||
|
|
0040d78496 | ||
|
|
2615f489d6 | ||
|
|
896de027cc | ||
|
|
fc329ebf37 | ||
|
|
15bc99f6ea | ||
|
|
91841a5519 | ||
|
|
eaab1d6824 | ||
|
|
0cfe310df6 | ||
|
|
918b6955e4 | ||
|
|
3ec7991e5f | ||
|
|
532fbf00d4 | ||
|
|
45b6fffd7f | ||
|
|
5a3eb08739 | ||
|
|
0dff329162 | ||
|
|
49c1740b47 | ||
|
|
3fbee51e9f | ||
|
|
a3dc56d2a0 | ||
|
|
63643c44a1 | ||
|
|
1d93608dbe | ||
|
|
d125b7de92 | ||
|
|
d5654ee316 | ||
|
|
3b34521ad9 | ||
|
|
7197fb350b | ||
|
|
6e349bfcc7 | ||
|
|
234056072d | ||
|
|
76330f4bff | ||
|
|
d468eec6ec | ||
|
|
40e85a6759 | ||
|
|
9bc6cc5b41 | ||
|
|
d109be159c | ||
|
|
eddf31e55b | ||
|
|
7e9d0db6aa | ||
|
|
2f1874ede5 | ||
|
|
6b83585b53 | ||
|
|
78ef04fcf1 | ||
|
|
b7e4f00c5f | ||
|
|
c20507c15e | ||
|
|
f7d0019df7 | ||
|
|
52364af5bf | ||
|
|
f410dd0440 | ||
|
|
eb5582c17c | ||
|
|
1c6cb2bec3 | ||
|
|
80b5e79e75 | ||
|
|
d182e893b6 | ||
|
|
2e8d49a641 | ||
|
|
6abd7d27d9 | ||
|
|
8fa12af403 | ||
|
|
77586ed7d3 | ||
|
|
394497fb2f | ||
|
|
fc7b6ef086 | ||
|
|
98edcad39d | ||
|
|
cc116ce67d | ||
|
|
1187aa8222 | ||
|
|
a35d66443b | ||
|
|
40ad4a42ea | ||
|
|
dc9b4dd017 | ||
|
|
68cb81a258 | ||
|
|
16693053f5 | ||
|
|
40efc2ba43 | ||
|
|
4e3bad3907 | ||
|
|
c874f19f2a | ||
|
|
f5f26f0cbe | ||
|
|
e7e3ca1efb | ||
|
|
4b00312fef | ||
|
|
c5fd3db01e | ||
|
|
e35ffaa925 | ||
|
|
f870a9d2a7 | ||
|
|
165e03f3a7 | ||
|
|
86bdb7808c | ||
|
|
b4e034be1c | ||
|
|
84fcebf538 | ||
|
|
74d9a1ffed | ||
|
|
a5a25dec57 | ||
|
|
c71905e5e8 | ||
|
|
bc78d668ac | ||
|
|
e93eebc2e9 | ||
|
|
5bd0896ad7 | ||
|
|
09ecfbcaed | ||
|
|
f0bd14b64f | ||
|
|
14f044ce4f | ||
|
|
88872baffc | ||
|
|
dbecf5330e | ||
|
|
1c0e102637 | ||
|
|
6b6b343922 | ||
|
|
f7d82fda3f | ||
|
|
706590c62a | ||
|
|
25c6b479c7 | ||
|
|
7cf9ff0345 | ||
|
|
209d74062a | ||
|
|
d86b13c9cb | ||
|
|
075e3ab69e | ||
|
|
49ef22ab78 | ||
|
|
ae4638712e | ||
|
|
c1c9483752 | ||
|
|
6c65fdf54b | ||
|
|
4874253d1e | ||
|
|
b72250349f | ||
|
|
116573311f | ||
|
|
4af712544d | ||
|
|
3f9c9591bd | ||
|
|
1548c567ab | ||
|
|
5b23fc570c | ||
|
|
04e1c7a05a | ||
|
|
9181e72204 | ||
|
|
b854ee4680 | ||
|
|
533a6bd15c | ||
|
|
45546c1cf7 | ||
|
|
e2169e3987 | ||
|
|
e85305c815 | ||
|
|
8d4554bf17 | ||
|
|
f628e4dcbb | ||
|
|
7accae4b6a | ||
|
|
3354fae391 | ||
|
|
4939865f6d | ||
|
|
3da7f7482e | ||
|
|
9072b029b2 | ||
|
|
c296cfb8c0 | ||
|
|
2707377fcb | ||
|
|
259f586ff7 | ||
|
|
d885b81f23 | ||
|
|
fe6bffd080 | ||
|
|
1a81e8a98a | ||
|
|
0b889c6028 | ||
|
|
f6bb0011f9 | ||
|
|
fcdd91895e | ||
|
|
8dc4fc4ff5 | ||
|
|
9e9a860bda | ||
|
|
6cd32028c3 | ||
|
|
ebd58ef33a | ||
|
|
92791194e5 | ||
|
|
1f7c58f7ce | ||
|
|
b9cdc2f54c | ||
|
|
5e23975d6e | ||
|
|
420937c848 | ||
|
|
e1a353ca20 | ||
|
|
250f212fa3 | ||
|
|
a275db3fdb | ||
|
|
95a3e32a12 | ||
|
|
233be6272a | ||
|
|
47cb52385e | ||
|
|
3c7a5afdcc | ||
|
|
5dc936a9a4 | ||
|
|
ba168ec003 | ||
|
|
a12e22c66f | ||
|
|
4c50a7281a | ||
|
|
80d3fa384e | ||
|
|
38f7e754ca | ||
|
|
157f16d3b2 | ||
|
|
b927b0cc6c | ||
|
|
493969a742 | ||
|
|
354f6582b2 | ||
|
|
fe3ebe3532 | ||
|
|
b45ede0b71 | ||
|
|
ac802a4646 | ||
|
|
a406ca2d5a | ||
|
|
6a258ff841 | ||
|
|
4649cadcb5 | ||
|
|
c287378167 | ||
|
|
0de86a390d | ||
|
|
c82d8e250a | ||
|
|
73db4e64f6 | ||
|
|
69ca0a8fac | ||
|
|
3b04e11544 | ||
|
|
e0927afa40 | ||
|
|
f97d9f3e11 | ||
|
|
b43610159f | ||
|
|
6d8609e457 | ||
|
|
dcd0ae7467 | ||
|
|
d216adeffc | ||
|
|
bb09708c02 | ||
|
|
1150d972a1 | ||
|
|
13bb7cf704 | ||
|
|
8bce696a7c | ||
|
|
6db8d2a28e | ||
|
|
2854e04bbb | ||
|
|
f3fd7a9fbd | ||
|
|
f99cddf97f | ||
|
|
0606a7762c | ||
|
|
f887f9985d | ||
|
|
550da0cee8 | ||
|
|
e662c020a9 | ||
|
|
7ff3936efe | ||
|
|
29594086c0 | ||
|
|
b0433c9f2a | ||
|
|
b1204b1423 | ||
|
|
43ca112fff | ||
|
|
24cf7fa6a2 | ||
|
|
bf66bcad86 | ||
|
|
f36a5f5654 | ||
|
|
c1facdff67 | ||
|
|
0263f9d35b | ||
|
|
101498e737 | ||
|
|
4ee46bc9f2 | ||
|
|
c3e94a8277 | ||
|
|
fafef32b9e | ||
|
|
1e764de0a8 | ||
|
|
b3b8d71dfc | ||
|
|
ca29c42805 | ||
|
|
fcefa2c820 | ||
|
|
6b6d030ed3 | ||
|
|
fd5b669c87 | ||
|
|
30d832c9b1 | ||
|
|
2448691136 | ||
|
|
e7cd7b5243 | ||
|
|
33f89a2609 | ||
|
|
403a731e22 | ||
|
|
3631fab7e2 | ||
|
|
b3d292a5f9 | ||
|
|
9293c685e0 | ||
|
|
38094a2339 | ||
|
|
538039f583 | ||
|
|
ca796510e9 | ||
|
|
d0d66cdcb7 | ||
|
|
d7d54fa2cc | ||
|
|
31649325f0 | ||
|
|
3a43ecb19b | ||
|
|
a709e5a12d | ||
|
|
f0ac77197b | ||
|
|
da0bbf2a3f | ||
|
|
295f34d7f0 | ||
|
|
c41ce77eea | ||
|
|
876b86ff91 | ||
|
|
acdfa1c87f | ||
|
|
4eb1e6093f | ||
|
|
189a066807 | ||
|
|
d0bada7a43 | ||
|
|
9dc0e6d08b | ||
|
|
8510fc313e | ||
|
|
2666708c30 | ||
|
|
f2b0ce13d9 | ||
|
|
b8652b7387 | ||
|
|
b18b2ebe9f | ||
|
|
9e5b1d24e8 | ||
|
|
a7dae6ad52 | ||
|
|
e93e05ae25 | ||
|
|
c8c27325dc | ||
|
|
c3b6f3918c | ||
|
|
bbb55a8ab4 | ||
|
|
04b2290927 | ||
|
|
53920b0399 | ||
|
|
58290760a9 | ||
|
|
8f522eed43 | ||
|
|
3dc001a9d2 | ||
|
|
ee54ee8825 | ||
|
|
2395b7a180 | ||
|
|
7583193c2a | ||
|
|
7cc3bd4ba0 | ||
|
|
88a0f095e8 | ||
|
|
c65f64dce0 | ||
|
|
d18cd217e1 | ||
|
|
ba4a1ab433 | ||
|
|
decddb521e | ||
|
|
33ab3a99f0 | ||
|
|
de6b1ada5d | ||
|
|
e08f48c7a1 | ||
|
|
851712a49e | ||
|
|
9e34323a40 | ||
|
|
95096bc3fc | ||
|
|
70897247b2 | ||
|
|
9c341f5aa5 | ||
|
|
f74a688fb9 | ||
|
|
e3e741d0be | ||
|
|
7c7c5fd967 | ||
|
|
fe8c7a62aa | ||
|
|
2af4a8dc12 | ||
|
|
0f53b952b2 | ||
|
|
7b2ae7377a | ||
|
|
c2ab288c7d | ||
|
|
dbb433fcf8 | ||
|
|
2abf00b5a6 | ||
|
|
275839e5c9 | ||
|
|
f30ffd5f5e | ||
|
|
bc9a24d705 | ||
|
|
2c879f13ef | ||
|
|
07b4a08979 | ||
|
|
497339f055 | ||
|
|
7f612bb069 | ||
|
|
5743b78694 | ||
|
|
2e6a2b655c | ||
|
|
cb47ac21bf | ||
|
|
a1394b4596 | ||
|
|
9e97948f03 | ||
|
|
8f780e7280 | ||
|
|
f7bfa8a05c | ||
|
|
46c6fb1e7a | ||
|
|
9f9fec5d4c | ||
|
|
e95be10485 | ||
|
|
f3d58fa0ce | ||
|
|
8c0eaa1f71 | ||
|
|
405df58f72 | ||
|
|
e7f13aa008 | ||
|
|
7cb6a9b89a | ||
|
|
9aa5344c29 | ||
|
|
8ba0ebbd2a | ||
|
|
c65407ab9f | ||
|
|
9e59685212 | ||
|
|
4a4dfaa910 | ||
|
|
0d6ecb0191 | ||
|
|
f16461bfe7 | ||
|
|
9fccc86b71 | ||
|
|
74683560a7 | ||
|
|
1e4f9dd438 | ||
|
|
b9ff916494 | ||
|
|
9bf4a0cad2 | ||
|
|
c32e2a8196 | ||
|
|
873d41582f | ||
|
|
6fb7d85558 | ||
|
|
6da7ed53f2 | ||
|
|
d5e3e32d58 | ||
|
|
f353a54555 | ||
|
|
1d6e2e751d | ||
|
|
cc50b63422 | ||
|
|
15ae83a15b | ||
|
|
81b369aed9 | ||
|
|
c8620d1633 | ||
|
|
ecc850bfb7 | ||
|
|
19b4ef33e0 | ||
|
|
7ca045d8b9 | ||
|
|
25b9df478c | ||
|
|
abfca6aab2 | ||
|
|
3c71c075db | ||
|
|
9c2992bfb2 | ||
|
|
269a1c5452 | ||
|
|
22ce65ac72 | ||
|
|
a2f8f59192 | ||
|
|
8c7c446f33 | ||
|
|
51611c25d7 | ||
|
|
eb1bbaa63b | ||
|
|
30a59168d7 | ||
|
|
4c8026ac3d | ||
|
|
8aeb4b7d54 | ||
|
|
b2172cb047 | ||
|
|
c8884f5e25 | ||
|
|
d9c6317c84 | ||
|
|
d29ec95526 | ||
|
|
ef4508dbc8 | ||
|
|
f775e46fe2 | ||
|
|
65ad5c0c9d | ||
|
|
88bf4e77ec | ||
|
|
194f66ca9c | ||
|
|
a4f8015caa | ||
|
|
ffd129909e | ||
|
|
9332316383 | ||
|
|
6dcbbf64c3 | ||
|
|
c9aa1ff99d | ||
|
|
2ce3553612 | ||
|
|
2e14f787d4 | ||
|
|
523b41ccd2 | ||
|
|
09970dc7af | ||
|
|
d81abd401c | ||
|
|
a6cba25bc1 | ||
|
|
c6fa1d0e67 | ||
|
|
ac56e1e88b | ||
|
|
a9ee971e1c | ||
|
|
73cef3a25a | ||
|
|
9b72ea9efa | ||
|
|
9f364441e8 | ||
|
|
e49a1c07bf | ||
|
|
5364a2471d | ||
|
|
fef4fdb0eb | ||
|
|
c2bf600a39 | ||
|
|
8d9f4edf9b | ||
|
|
020e61d0da | ||
|
|
6184c43319 | ||
|
|
2cbe4a790c | ||
|
|
68b3565d7b | ||
|
|
3f385a8572 | ||
|
|
9823dc35e1 | ||
|
|
059bfee91b | ||
|
|
7beaf0eaa2 | ||
|
|
1fef90ff58 | ||
|
|
8447fd27a0 | ||
|
|
7831cba9f6 | ||
|
|
e02b2d58d5 | ||
|
|
28726632a9 | ||
|
|
0f63d973be | ||
|
|
3b26129c82 | ||
|
|
d4bb4e6624 | ||
|
|
fa2abd560a | ||
|
|
0766c49f93 | ||
|
|
a7ffc77e3d | ||
|
|
e641fde25c | ||
|
|
564c2d763e | ||
|
|
5717c7f2f4 | ||
|
|
8734d4cb90 | ||
|
|
2f6004d74a | ||
|
|
08779cc8a8 | ||
|
|
5baa753539 | ||
|
|
92fb6b012a | ||
|
|
ead98e4bca | ||
|
|
a1634909e8 | ||
|
|
8f06f6a9ed | ||
|
|
ace7c0ccb4 | ||
|
|
f87fe0a0e8 | ||
|
|
87edc6f35e | ||
|
|
1d2fe55310 | ||
|
|
c175821cc4 | ||
|
|
239a28793c | ||
|
|
c421d653e7 | ||
|
|
2542c2920d | ||
|
|
52e46ced1b | ||
|
|
cf9daf470c | ||
|
|
ac7738bdeb | ||
|
|
2d9f6c104c | ||
|
|
5d0460ece2 | ||
|
|
140d6211cc | ||
|
|
60f9a1442c | ||
|
|
cb6caf3f87 | ||
|
|
c9301a6d18 | ||
|
|
0e77e93e5d | ||
|
|
99c7abbbf1 | ||
|
|
8f511ac33c | ||
|
|
1046152119 | ||
|
|
f88228f1c5 | ||
|
|
62e2b672d9 | ||
|
|
03005b5d29 | ||
|
|
c7e8830a56 | ||
|
|
d5ef4a6d15 | ||
|
|
97b67e0e49 | ||
|
|
dd6d78cb31 | ||
|
|
46433a25f8 | ||
|
|
b4e070697d | ||
|
|
c8843edb81 | ||
|
|
f89feb881c | ||
|
|
dbba71028e | ||
|
|
8549a92e9a | ||
|
|
109cffc010 | ||
|
|
f8f3ad84fc | ||
|
|
93d7883513 | ||
|
|
015a3e8a83 | ||
|
|
bc7167e9fe | ||
|
|
384578a88c | ||
|
|
6b074653f2 | ||
|
|
65b4e1ec6c | ||
|
|
06afa29f2d | ||
|
|
6600d58ba2 | ||
|
|
25e9be3ced | ||
|
|
ccb2aaf2fe | ||
|
|
961c6f67da | ||
|
|
dc4305f75a | ||
|
|
4dc7af5a5d | ||
|
|
902bea24b4 | ||
|
|
778cf4af9e | ||
|
|
c3ef46f409 | ||
|
|
4721c58d9c | ||
|
|
aa0b63e214 | ||
|
|
3c4e7997c3 | ||
|
|
1afc3a5f65 | ||
|
|
ea3d22831e | ||
|
|
3b4d6d359b | ||
|
|
48cba39a12 | ||
|
|
bca244df67 | ||
|
|
cec4e251bd | ||
|
|
526dd866ba | ||
|
|
c29839d2ed | ||
|
|
b31ddc7bf1 | ||
|
|
22e1ad3d8a | ||
|
|
f571b1deb0 | ||
|
|
67f8732683 | ||
|
|
2b387e169b | ||
|
|
199cf480b0 | ||
|
|
18daa023cb | ||
|
|
4ad6189487 | ||
|
|
8950d92682 | ||
|
|
fe5b3c80cb | ||
|
|
0ffcce3ec8 | ||
|
|
e0ffec885c | ||
|
|
ff4ff6bc2f | ||
|
|
f4fcfc5867 | ||
|
|
7248f65c36 | ||
|
|
5c40a2db21 | ||
|
|
d6111344c5 | ||
|
|
086eb3df7a | ||
|
|
ee2976cca0 | ||
|
|
8bc6df329f | ||
|
|
bcd4d9595f | ||
|
|
5a77b7728e | ||
|
|
1fbbba6f59 | ||
|
|
847be0e99d | ||
|
|
f6a2d072e6 | ||
|
|
ed8b0f25ee | ||
|
|
6e4a602c60 | ||
|
|
2262479365 | ||
|
|
33d66959e9 | ||
|
|
7f1b2b3f6e | ||
|
|
40ee065eff | ||
|
|
a75fb6af90 | ||
|
|
72f2125668 | ||
|
|
e8f5888d8e | ||
|
|
0b06d637e7 | ||
|
|
496f6770a5 | ||
|
|
5a7e5bd870 | ||
|
|
6f8a8f8136 | ||
|
|
5df195ea82 | ||
|
|
f82f70df5c | ||
|
|
5a2bf191fc | ||
|
|
a235fb1507 | ||
|
|
0d66522ed8 | ||
|
|
b163f8ed9e | ||
|
|
83e5f60b8b | ||
|
|
5b433f962f | ||
|
|
a1da6ff5ac | ||
|
|
5977af96a0 | ||
|
|
9b33fbf1cd | ||
|
|
43652d044c | ||
|
|
b1b379ea18 | ||
|
|
21ac161b21 | ||
|
|
94e979865e | ||
|
|
6c324f2c8b | ||
|
|
e0194d8511 | ||
|
|
216dafe44b | ||
|
|
d7dc9660af | ||
|
|
e0e30df323 | ||
|
|
543dfd67e0 | ||
|
|
bbd3eafde0 | ||
|
|
e9cd355893 | ||
|
|
c3e39267b8 | ||
|
|
b477aff611 | ||
|
|
28bd1323a2 | ||
|
|
220ca45f74 | ||
|
|
70a82d80ac | ||
|
|
ac626111ac | ||
|
|
8f6740fcef | ||
|
|
d829ac4cf7 | ||
|
|
f064f6e59d | ||
|
|
5bb9c2a2bd | ||
|
|
0b5bbe9234 | ||
|
|
14c74e5e84 | ||
|
|
6448d0ee7c | ||
|
|
b0c17af2cf | ||
|
|
8f27fd5c42 | ||
|
|
a9823ba58a | ||
|
|
8cfe26f10c | ||
|
|
80db2dc254 | ||
|
|
e8e3bc8616 | ||
|
|
ab5f5386e4 | ||
|
|
bc3195c8d8 | ||
|
|
6494330c6b | ||
|
|
89e34bf1e6 | ||
|
|
2574eec2ed | ||
|
|
514b9bf9fc | ||
|
|
4d7f389b69 | ||
|
|
95f87d5669 | ||
|
|
c83365a349 | ||
|
|
6b3604cf2b | ||
|
|
af6bdca14f | ||
|
|
58d45b4d58 | ||
|
|
1906ebcfce | ||
|
|
1c773c428f | ||
|
|
e785bfcd12 | ||
|
|
47dacce6ea | ||
|
|
dcac3407ab | ||
|
|
7004295e1d | ||
|
|
ee62ef4745 | ||
|
|
ef6bafbf7e | ||
|
|
ed28b71e87 | ||
|
|
d47b7dc79a | ||
|
|
49b9709ce5 | ||
|
|
a2eba2cdf5 | ||
|
|
3d01b3cfe8 | ||
|
|
af2efa6f7e | ||
|
|
d73b61d367 | ||
|
|
d3533f81fc | ||
|
|
59a448b645 | ||
|
|
3de7a7f0cd | ||
|
|
4adb9eed77 | ||
|
|
b6a0f7a07f | ||
|
|
b2566368f8 | ||
|
|
1b2f907671 | ||
|
|
bda04eed8a | ||
|
|
e0735977b5 | ||
|
|
67985d8226 | ||
|
|
cbcb061812 | ||
|
|
9fc2e1b3c8 | ||
|
|
3b484aea9e | ||
|
|
963a0950fa | ||
|
|
1fb4f2b12e | ||
|
|
f4ba1ab910 | ||
|
|
2662f91082 | ||
|
|
f5967069f2 | ||
|
|
80f5523685 | ||
|
|
c1db2c7d7c | ||
|
|
5e5d8142f9 | ||
|
|
b01619b441 | ||
|
|
109cf3928a | ||
|
|
4794645dec | ||
|
|
f861bd6a94 | ||
|
|
6dbfdd140d | ||
|
|
aa8526edc0 | ||
|
|
ac3ca0ad8e | ||
|
|
fe6043aec7 | ||
|
|
386ccffed4 | ||
|
|
08d21b76e2 | ||
|
|
ffddd1c90a | ||
|
|
33aa665555 | ||
|
|
00280b6fe8 | ||
|
|
08e8fddf73 | ||
|
|
8f8dfd081b | ||
|
|
9f1b445c7c | ||
|
|
ae933dfe14 | ||
|
|
5d33d6b8ea | ||
|
|
e124db723b | ||
|
|
05444cf32d | ||
|
|
478aff1189 | ||
|
|
8edbda57cf | ||
|
|
52760a4eaa | ||
|
|
bc32096e9c | ||
|
|
821249a5ed | ||
|
|
2331b9a2e7 | ||
|
|
ee33863b47 | ||
|
|
cd22c849e2 | ||
|
|
f0e73efda2 | ||
|
|
3156109c71 | ||
|
|
6762e081f3 | ||
|
|
5ca3508284 | ||
|
|
771fec9447 | ||
|
|
7815ee338d | ||
|
|
44b6c872e2 | ||
|
|
7a77b23f2d | ||
|
|
672e8549c0 | ||
|
|
66f5269a23 | ||
|
|
4eaf769894 | ||
|
|
ebec293497 | ||
|
|
e02ceecd35 | ||
|
|
9116392a45 | ||
|
|
c8b33a8cc3 | ||
|
|
dca8d5ded8 | ||
|
|
2a7fd1e897 | ||
|
|
b9d1e70ac2 | ||
|
|
fdf5720217 | ||
|
|
f40bd0cd51 | ||
|
|
e33676bb87 | ||
|
|
b1f1cee1e5 | ||
|
|
a1ecc9ab00 | ||
|
|
2a663d5cba | ||
|
|
ba486ca6b7 | ||
|
|
750b930679 | ||
|
|
3902fd7501 | ||
|
|
4fc3d5e935 | ||
|
|
2d2f4572a7 | ||
|
|
8f4c46f38d | ||
|
|
b6ba51bc2a | ||
|
|
6a66d32d37 | ||
|
|
8d15723195 | ||
|
|
736e0aae86 | ||
|
|
8bf3305b2b | ||
|
|
d00e3ea973 | ||
|
|
89db4e9481 | ||
|
|
e332419081 | ||
|
|
19232c6388 | ||
|
|
e998b1229a | ||
|
|
bbed134bd1 | ||
|
|
47b9503112 | ||
|
|
3b9253c2be | ||
|
|
d241359153 | ||
|
|
f4d4249ba5 | ||
|
|
06075c0e93 | ||
|
|
e85c9d9322 | ||
|
|
cb56cb250e | ||
|
|
e0381a6ae0 | ||
|
|
2c01b2ef64 | ||
|
|
e947266743 | ||
|
|
c6b0e85b54 | ||
|
|
26efbed05c | ||
|
|
96340bf136 | ||
|
|
b055e00c1a | ||
|
|
414db44c00 | ||
|
|
857c880f99 | ||
|
|
ce7474d953 | ||
|
|
70fdd70b84 | ||
|
|
08ab6a7d77 | ||
|
|
5c95129884 | ||
|
|
9fa2a7e9df | ||
|
|
d443c86620 | ||
|
|
7be3f1c36c | ||
|
|
f6ab6d97b9 | ||
|
|
bc866bac49 | ||
|
|
50e6d845f4 | ||
|
|
a8cb01819d | ||
|
|
1a99cfded4 | ||
|
|
530273906b | ||
|
|
06ddf575d9 | ||
|
|
cf369d4684 | ||
|
|
3099114cbb | ||
|
|
44b63f0767 | ||
|
|
6705d20194 | ||
|
|
a38a9c0b0f | ||
|
|
8286caa366 | ||
|
|
bd1ec8424d | ||
|
|
225e2c6797 | ||
|
|
d8fc485513 | ||
|
|
f137eb0ac4 | ||
|
|
f39a460487 | ||
|
|
ee171bc563 | ||
|
|
a95428f204 | ||
|
|
cb3bdffb43 | ||
|
|
48f19aab51 | ||
|
|
48f6d7abdf | ||
|
|
79fbcb3ec4 | ||
|
|
0e4148b229 | ||
|
|
3ca5fb1046 | ||
|
|
a091d12f4e | ||
|
|
e3d8d726e6 | ||
|
|
457924828a | ||
|
|
aca2ef6359 | ||
|
|
ade7194792 | ||
|
|
0f51e73baa | ||
|
|
3a436e116a | ||
|
|
d06e2dc83c | ||
|
|
336867853b | ||
|
|
6403ff4ec4 | ||
|
|
d222469b44 | ||
|
|
790a17ce98 | ||
|
|
d473c952fb | ||
|
|
7646a2b877 | ||
|
|
62090f2568 | ||
|
|
d35152bbef | ||
|
|
c281f4cbaf | ||
|
|
09455f9e85 | ||
|
|
c8e72ba0dc | ||
|
|
375ef252ab | ||
|
|
ee552f8720 | ||
|
|
2e88c4858e | ||
|
|
3f50da85c1 | ||
|
|
8be06255f7 | ||
|
|
60936b5185 | ||
|
|
72274099aa | ||
|
|
b7f7b3a1d8 | ||
|
|
dcae098e23 | ||
|
|
618606966f | ||
|
|
05f249d77f | ||
|
|
2eb05ec640 | ||
|
|
3ce0d76aa4 | ||
|
|
a00b79d9be | ||
|
|
9fe6a215e6 | ||
|
|
33e53a2a56 | ||
|
|
cd5b80785f | ||
|
|
54f71aa273 | ||
|
|
3f949b7f84 | ||
|
|
cf8b2dcc85 | ||
|
|
8e24d9dc34 | ||
|
|
443c4538bb | ||
|
|
a7fc2ee4cf | ||
|
|
8e749ac22d | ||
|
|
69e09d9bc7 | ||
|
|
ed57d82bc1 | ||
|
|
06ad527e8c | ||
|
|
7af5a90a0b | ||
|
|
7551faff79 | ||
|
|
b7409dd2de | ||
|
|
5ba325a8fc | ||
|
|
d502840f91 | ||
|
|
99238a4b59 | ||
|
|
6d43a2ff9a | ||
|
|
cdb9c2e6e8 | ||
|
|
3faa1ca9af | ||
|
|
9d975e0375 | ||
|
|
2a6d8b78d4 | ||
|
|
671558a822 | ||
|
|
6b80ec79a0 | ||
|
|
d3f4783a24 | ||
|
|
1cb6bdbc87 | ||
|
|
96ddfc1f24 | ||
|
|
c169b32570 | ||
|
|
36a512fdf2 | ||
|
|
26fbb77901 | ||
|
|
a277302262 | ||
|
|
969c1a5b72 | ||
|
|
872339bceb | ||
|
|
5dc0dbc7aa | ||
|
|
ee6fc4e8a1 | ||
|
|
8fee16aecd | ||
|
|
2b7ba54a2f | ||
|
|
007c3304f2 | ||
|
|
e76ba0ede9 | ||
|
|
c06ac07e23 | ||
|
|
e592a57458 | ||
|
|
66769ec657 | ||
|
|
f413feec61 | ||
|
|
2e538e3486 | ||
|
|
9617a7b0d6 | ||
|
|
7569320770 | ||
|
|
8d25cf0d75 | ||
|
|
64e85e7019 | ||
|
|
a862984dca | ||
|
|
f0365f0465 | ||
|
|
6d1e20e940 | ||
|
|
0c0aae1eac | ||
|
|
5dcf7cb846 | ||
|
|
349b2ba3af | ||
|
|
98db5aabd0 | ||
|
|
e52b542e22 | ||
|
|
8f6abb8a86 | ||
|
|
ed8eaae964 | ||
|
|
7fd98f3556 | ||
|
|
e8de87ee90 | ||
|
|
4e572ec8b9 | ||
|
|
6c7f18c448 | ||
|
|
24bc9cba67 | ||
|
|
97356b1a04 | ||
|
|
1084b53fba | ||
|
|
b1aecc2bf1 | ||
|
|
83b90e106f | ||
|
|
f52114dab2 | ||
|
|
5106caf641 | ||
|
|
12370ee84e | ||
|
|
b84ccc6e7a | ||
|
|
e19ddb53e7 | ||
|
|
5bf89dd757 | ||
|
|
2a0100b2d6 | ||
|
|
4442574e53 | ||
|
|
c020fa60d0 | ||
|
|
b078be4613 | ||
|
|
71a6dffbb6 | ||
|
|
5f65dd5bb4 | ||
|
|
27b43ed63f | ||
|
|
f6a3a1d0ba | ||
|
|
830fd8eac2 | ||
|
|
a86d501dc2 | ||
|
|
24e8e20b59 | ||
|
|
e755e567ea | ||
|
|
a87f09bad2 | ||
|
|
dbcbe48ead | ||
|
|
63908869f6 | ||
|
|
db491c8f9b | ||
|
|
f6d625114c | ||
|
|
7dc40ba6d4 | ||
|
|
fcd6475377 | ||
|
|
4070c9de81 | ||
|
|
1e9e4a86a2 | ||
|
|
406a27271a | ||
|
|
9f9a4fc2af | ||
|
|
3fc410a253 | ||
|
|
781bc1521b | ||
|
|
05d201ece8 | ||
|
|
cd0c94f48a | ||
|
|
293cc8c1a3 | ||
|
|
453e744abf | ||
|
|
653439698e | ||
|
|
24970baa57 | ||
|
|
5418bbc338 | ||
|
|
89254cfc97 | ||
|
|
6bd9a034f7 | ||
|
|
26fc65b051 | ||
|
|
ed5ec5b55c | ||
|
|
df777650ac | ||
|
|
9855615f1e | ||
|
|
93414f1baa | ||
|
|
8fac6b147a | ||
|
|
10f8c795ac | ||
|
|
3e4858a624 | ||
|
|
1231dc9cda | ||
|
|
c84ff42bcd | ||
|
|
40d78908ed | ||
|
|
8a5db02165 | ||
|
|
56fa81f3c6 | ||
|
|
d7afb6eb0c | ||
|
|
03209b35c0 | ||
|
|
bbd1fe890a | ||
|
|
843316ea7a | ||
|
|
f607231efa | ||
|
|
2039062845 | ||
|
|
44f66d2257 | ||
|
|
99478d13a8 | ||
|
|
3b51a0fe12 | ||
|
|
2d91c2a3f5 | ||
|
|
bc6c4cdbfc | ||
|
|
69d3a80fc3 | ||
|
|
404546ce93 | ||
|
|
9e268ad103 | ||
|
|
6dd1cf1dd6 | ||
|
|
9058d406a3 | ||
|
|
9d9b9e7a0d | ||
|
|
13aa82f3f3 | ||
|
|
4ea5586b6f | ||
|
|
05e55d7dc5 | ||
|
|
1b358c931c | ||
|
|
e04b02113a | ||
|
|
3275494fde | ||
|
|
e3af8783b9 | ||
|
|
ca09db21ff | ||
|
|
c1f8211acb | ||
|
|
718ff7a73f | ||
|
|
fa70b220e9 | ||
|
|
98fa2a1597 | ||
|
|
0e7c79ba23 | ||
|
|
b6ba15fcbd | ||
|
|
e44167d7a4 | ||
|
|
1bfa75f780 | ||
|
|
bbcb5552f3 | ||
|
|
31bd90c748 | ||
|
|
1b8cb7b77b | ||
|
|
774f1fbc17 | ||
|
|
cfa8ddb59f | ||
|
|
39597267ae | ||
|
|
393e38f2c0 | ||
|
|
0f646800f6 | ||
|
|
ca993238f3 | ||
|
|
d1220de02d | ||
|
|
cf9a246d53 | ||
|
|
13eb5268de | ||
|
|
88798816f2 | ||
|
|
598f0af19b | ||
|
|
a33f5d31fc | ||
|
|
54acd69e9d | ||
|
|
d687ee2777 | ||
|
|
54c2fefbad | ||
|
|
506699fba1 | ||
|
|
f7b17ee6ec | ||
|
|
408614c74c | ||
|
|
68a27772b3 | ||
|
|
de87fb622b | ||
|
|
0155a01bb1 | ||
|
|
cfeee5d511 | ||
|
|
f27672f6cf | ||
|
|
28420c14e4 | ||
|
|
10e0ea1309 | ||
|
|
0bd221ff41 | ||
|
|
5fda6f8ef3 | ||
|
|
9b956f6338 | ||
|
|
09923f654c | ||
|
|
ae7b972649 | ||
|
|
47885e3710 | ||
|
|
4b9a260b37 | ||
|
|
462a70541e | ||
|
|
2407c1f4af | ||
|
|
2c743c8f0b | ||
|
|
9f2c278ee6 | ||
|
|
aea337cfe2 | ||
|
|
811f8f8b4f | ||
|
|
27734a23b1 | ||
|
|
1b8e538a77 | ||
|
|
41c2385aca | ||
|
|
d605985f45 | ||
|
|
d52b28b147 | ||
|
|
4afe1f42ca | ||
|
|
7481c0eaa0 | ||
|
|
024bc25b2c | ||
|
|
ffdfad8482 | ||
|
|
b91ee8d008 | ||
|
|
6586f08584 | ||
|
|
92c62bb2fb | ||
|
|
f49e887fe6 | ||
|
|
344066fd11 | ||
|
|
bcb8092488 | ||
|
|
1efade8bdb | ||
|
|
a5b3ff11fd | ||
|
|
084558f200 | ||
|
|
b602eae215 | ||
|
|
d02bf9c243 | ||
|
|
26a5f67df2 | ||
|
|
600fd42a83 | ||
|
|
670685139a | ||
|
|
52b6306388 | ||
|
|
f957b8948c | ||
|
|
cd0b14dd2d | ||
|
|
894703a484 | ||
|
|
521ec6f1b8 | ||
|
|
b0c5d9640a | ||
|
|
ef8e94e992 | ||
|
|
9df96a4bb4 | ||
|
|
28a428ae2f | ||
|
|
b326ec3641 | ||
|
|
fcecbc7d46 | ||
|
|
f4007f53ba | ||
|
|
d08a2453f7 | ||
|
|
3f53eea1e0 | ||
|
|
5a812a1e93 | ||
|
|
5e624cc7b1 | ||
|
|
f3d1cc8dc1 | ||
|
|
e889efeda7 | ||
|
|
0a3a95521c | ||
|
|
4ebaf6f7a9 | ||
|
|
59ac1a3f60 | ||
|
|
3af24597ee | ||
|
|
0b834fcb54 | ||
|
|
e0be6c5786 | ||
|
|
88b101ebf5 | ||
|
|
923a5d6efb | ||
|
|
734b7e42ad | ||
|
|
d9a65745df | ||
|
|
97ab623d42 | ||
|
|
14aa6cc7e8 | ||
|
|
10e77fcf24 | ||
|
|
bbb21d7c2b | ||
|
|
3bc489254b | ||
|
|
4c07ea41c3 | ||
|
|
f6720f8dfa | ||
|
|
e19ab3a066 | ||
|
|
c46099c5d7 | ||
|
|
8f1dd69e72 | ||
|
|
f26da24a2f | ||
|
|
407020de0c | ||
|
|
8e4fbcaa7d | ||
|
|
09c339953d | ||
|
|
367a05bdf6 | ||
|
|
d20b71deb9 | ||
|
|
712ce9f781 | ||
|
|
a4a3274a55 | ||
|
|
716aa71f6e | ||
|
|
e8976f9898 | ||
|
|
8496cc2444 | ||
|
|
5ef2d59e05 | ||
|
|
07bb89ae80 | ||
|
|
27a5ad8ec2 | ||
|
|
707b07c5f5 | ||
|
|
4a764afd76 | ||
|
|
ecf49d574b | ||
|
|
188de4ff2a | ||
|
|
5a75ef8ffd | ||
|
|
07279f8746 | ||
|
|
71f788b13a | ||
|
|
59c62dc580 | ||
|
|
8fb1f114bc | ||
|
|
6a4cff6699 | ||
|
|
d5310a3300 | ||
|
|
de0ea3ac49 | ||
|
|
12116b018d | ||
|
|
c3ed3b40ea | ||
|
|
b80c2aabb0 | ||
|
|
f0a3eb574e | ||
|
|
bb15855443 | ||
|
|
14ce6aebd1 | ||
|
|
2fe83723f2 | ||
|
|
e73b9e10a6 | ||
|
|
9c04c18c04 | ||
|
|
81ae09d0ec | ||
|
|
01cf221167 | ||
|
|
cd8c86c6fb | ||
|
|
52d5fd1a67 | ||
|
|
7ecc7aabda | ||
|
|
79033aee34 | ||
|
|
b6ad243e9e | ||
|
|
92ca5078c1 | ||
|
|
aca8523060 | ||
|
|
1ea0cff3a4 | ||
|
|
75793a18f0 | ||
|
|
58866b21cb | ||
|
|
660aabc437 | ||
|
|
db80b20bc2 | ||
|
|
566120e8d5 | ||
|
|
f3f0f1717d | ||
|
|
05b499fb83 | ||
|
|
7621ec609e | ||
|
|
9f511f0024 | ||
|
|
374faa2640 | ||
|
|
ba6aa5fbbe | ||
|
|
1c52a89535 | ||
|
|
e7cedbee6e | ||
|
|
75ce0919a0 | ||
|
|
7f4f6bc9ca | ||
|
|
b8194e717c | ||
|
|
15c3cc3a50 | ||
|
|
d131435e25 | ||
|
|
6e43669498 | ||
|
|
5ab3032335 | ||
|
|
1215c635a0 | ||
|
|
54d4fd7f84 | ||
|
|
8dc690a638 | ||
|
|
fdeb84db2b | ||
|
|
84920cb670 | ||
|
|
204bba9dea | ||
|
|
35fdd7bc05 | ||
|
|
fc054db51a | ||
|
|
6e2306a5f2 | ||
|
|
b09e2115d1 | ||
|
|
6a94afab6c | ||
|
|
a68c97a40f | ||
|
|
218dc17713 | ||
|
|
cd2da152d4 | ||
|
|
28469576bf | ||
|
|
40e7f066e4 | ||
|
|
ef0edbfe69 | ||
|
|
bb6312b4fc | ||
|
|
3c315551b0 | ||
|
|
27c9c5c4da | ||
|
|
fc9f6c974a | ||
|
|
242b4d5754 | ||
|
|
4ce7c61a17 | ||
|
|
a74ee3f319 | ||
|
|
564bcbaa54 | ||
|
|
88bdd25f06 | ||
|
|
e79f65fd8e | ||
|
|
2760989401 | ||
|
|
facfe7c518 | ||
|
|
6285459c08 | ||
|
|
21bbceca0c | ||
|
|
f6300c72b7 | ||
|
|
007572b58e | ||
|
|
3a81ab22fd | ||
|
|
519da2e042 | ||
|
|
169f4295d0 | ||
|
|
d06d0eab2f | ||
|
|
3ffd120ae9 | ||
|
|
a03d514095 | ||
|
|
07d21463ca | ||
|
|
69fccf0015 | ||
|
|
1da03bfe15 | ||
|
|
6133bac226 | ||
|
|
f302be5ce6 | ||
|
|
cd4e84a360 | ||
|
|
4360ed8a7b | ||
|
|
423ce97665 | ||
|
|
b27a175fef | ||
|
|
8d5f89ccfd | ||
|
|
084e2666cb | ||
|
|
2eb2dbb266 | ||
|
|
e717939edb | ||
|
|
7758a86d1e | ||
|
|
76c563d161 | ||
|
|
a89514951f | ||
|
|
9f72a875f8 | ||
|
|
94d61c7b2b | ||
|
|
f999650322 | ||
|
|
1249b07eb8 | ||
|
|
38319b0483 | ||
|
|
6b37f33d31 | ||
|
|
af543238aa | ||
|
|
2de27c560b | ||
|
|
773ed6cc64 | ||
|
|
a594338bc5 | ||
|
|
f25f419e5a | ||
|
|
1fd1ccca17 | ||
|
|
b7e382008f | ||
|
|
70d6b95097 | ||
|
|
9b202b6c1c | ||
|
|
6a66b6801a | ||
|
|
5b6d201408 | ||
|
|
5ec9b5e5a9 | ||
|
|
5db3b58717 | ||
|
|
1fa5514d56 | ||
|
|
347769b3e3 | ||
|
|
3cfe7008a2 | ||
|
|
c353606860 | ||
|
|
2ba31ecc2d | ||
|
|
da23ddb061 | ||
|
|
39b6b3b289 | ||
|
|
c600519fa4 | ||
|
|
e5312fb5a2 | ||
|
|
36380846a7 | ||
|
|
92df0cada9 | ||
|
|
96b55acff8 | ||
|
|
608cd8ee3d | ||
|
|
9f41894573 | ||
|
|
bb45fee1cf | ||
|
|
af00304b0c | ||
|
|
5c3a013cd1 | ||
|
|
ab9e9442ec | ||
|
|
6ad188921c | ||
|
|
15ed98d6a9 | ||
|
|
a283545b6b | ||
|
|
3efbd865a8 | ||
|
|
aee659fb66 | ||
|
|
5aa386d8b9 | ||
|
|
0adc0ee6aa | ||
|
|
92f13fc316 | ||
|
|
05cfa16e5f | ||
|
|
93a6e2d920 | ||
|
|
6b49580716 | ||
|
|
de77903915 | ||
|
|
56ed0d8d90 | ||
|
|
0d4f32a881 | ||
|
|
42e818ce05 | ||
|
|
2d4c54ba54 | ||
|
|
e9eb4db8bb | ||
|
|
d26ed069fa | ||
|
|
afcab5efda | ||
|
|
1770c491db | ||
|
|
a0c6cffb0d | ||
|
|
2bf9e08b31 | ||
|
|
f56bfaa689 | ||
|
|
5d716dc796 | ||
|
|
f81ff16022 | ||
|
|
6cf1d8a947 | ||
|
|
a174d015f2 | ||
|
|
9c09128e00 | ||
|
|
68cbe20664 | ||
|
|
549c0c2c5a | ||
|
|
f092801b61 | ||
|
|
15353a6b6a | ||
|
|
1b638b3629 | ||
|
|
6f5f81753d | ||
|
|
76af454034 | ||
|
|
e54d2f6b2a | ||
|
|
bfc738b76a | ||
|
|
396899a530 | ||
|
|
04f0070a80 | ||
|
|
f383840cf9 | ||
|
|
239fc4a8c4 | ||
|
|
fd29ab418a | ||
|
|
df91408919 | ||
|
|
7a628426dc | ||
|
|
aa6c7facab | ||
|
|
8ba4c7c7ed | ||
|
|
56b4d7a76e | ||
|
|
b211c3546d | ||
|
|
edc654edf9 | ||
|
|
08586334af | ||
|
|
a4804b358f | ||
|
|
7ea14479fb | ||
|
|
54af96d321 | ||
|
|
22579155c5 | ||
|
|
c04c3832a4 | ||
|
|
5ffbd54755 | ||
|
|
5d12d4ce33 | ||
|
|
b73e53d6c4 | ||
|
|
b06463c6d9 | ||
|
|
5eb8453e91 | ||
|
|
f77c22e6ff | ||
|
|
df83ba877f | ||
|
|
9583f6b1c5 | ||
|
|
02d8a1cfec | ||
|
|
92f033dec0 | ||
|
|
0ebabf5152 | ||
|
|
4b01ecba2e | ||
|
|
d7564173dd | ||
|
|
f241124599 | ||
|
|
c44c46dd80 | ||
|
|
aa810ee719 | ||
|
|
412148af0e | ||
|
|
5d2baf6058 | ||
|
|
d28258501a | ||
|
|
55cd31fb96 | ||
|
|
d138df07bf | ||
|
|
c5df8e7897 | ||
|
|
d4d529833d | ||
|
|
caa48e7c6f | ||
|
|
acdfb3bceb | ||
|
|
89d68962b1 | ||
|
|
691cdb6bdf | ||
|
|
8064cba288 | ||
|
|
361443db10 | ||
|
|
d6352dd4d4 | ||
|
|
f8aba62860 | ||
|
|
a7eeb06f3d | ||
|
|
9426be7a5c | ||
|
|
4a135f1986 | ||
|
|
c4c02f4ad0 | ||
|
|
b87b9b455f | ||
|
|
db03ae9663 | ||
|
|
969ff6bb68 | ||
|
|
e24af6e545 | ||
|
|
bceecfb2e3 |
@@ -13,8 +13,6 @@ Dockerfile
|
||||
docs/*
|
||||
README.md
|
||||
README_CN.md
|
||||
MANAGEMENT_API.md
|
||||
MANAGEMENT_API_CN.md
|
||||
LICENSE
|
||||
|
||||
# Runtime data folders (should be mounted as volumes)
|
||||
@@ -25,6 +23,15 @@ config.yaml
|
||||
|
||||
# Development/editor
|
||||
bin/*
|
||||
.claude/*
|
||||
.vscode/*
|
||||
.claude/*
|
||||
.codex/*
|
||||
.gemini/*
|
||||
.serena/*
|
||||
.agent/*
|
||||
.agents/*
|
||||
.opencode/*
|
||||
.idea/*
|
||||
.bmad/*
|
||||
_bmad/*
|
||||
_bmad-output/*
|
||||
|
||||
7
.github/ISSUE_TEMPLATE/bug_report.md
vendored
7
.github/ISSUE_TEMPLATE/bug_report.md
vendored
@@ -7,6 +7,13 @@ assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Is it a request payload issue?**
|
||||
[ ] Yes, this is a request payload issue. I am using a client/cURL to send a request payload, but I received an unexpected error.
|
||||
[ ] No, it's another issue.
|
||||
|
||||
**If it's a request payload issue, you MUST know**
|
||||
Our team doesn't have any GODs or ORACLEs or MIND READERs. Please make sure to attach the request log or curl payload.
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
|
||||
114
.github/workflows/docker-image.yml
vendored
114
.github/workflows/docker-image.yml
vendored
@@ -1,22 +1,21 @@
|
||||
name: docker-image
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
tags:
|
||||
- v*
|
||||
|
||||
env:
|
||||
APP_NAME: CLIProxyAPI
|
||||
DOCKERHUB_REPO: eceasy/cli-proxy-api-plus
|
||||
DOCKERHUB_REPO: ${{ secrets.DOCKERHUB_USERNAME }}/cli-proxy-api-plus
|
||||
|
||||
jobs:
|
||||
docker:
|
||||
docker_amd64:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Login to DockerHub
|
||||
@@ -29,18 +28,113 @@ jobs:
|
||||
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||
- name: Build and push
|
||||
- name: Build and push (amd64)
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
platforms: |
|
||||
linux/amd64
|
||||
linux/arm64
|
||||
platforms: linux/amd64
|
||||
push: true
|
||||
build-args: |
|
||||
VERSION=${{ env.VERSION }}
|
||||
COMMIT=${{ env.COMMIT }}
|
||||
BUILD_DATE=${{ env.BUILD_DATE }}
|
||||
tags: |
|
||||
${{ env.DOCKERHUB_REPO }}:latest
|
||||
${{ env.DOCKERHUB_REPO }}:${{ env.VERSION }}
|
||||
${{ env.DOCKERHUB_REPO }}:latest-amd64
|
||||
${{ env.DOCKERHUB_REPO }}:${{ env.VERSION }}-amd64
|
||||
|
||||
docker_arm64:
|
||||
runs-on: ubuntu-24.04-arm
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Generate Build Metadata
|
||||
run: |
|
||||
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||
- name: Build and push (arm64)
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/arm64
|
||||
push: true
|
||||
build-args: |
|
||||
VERSION=${{ env.VERSION }}
|
||||
COMMIT=${{ env.COMMIT }}
|
||||
BUILD_DATE=${{ env.BUILD_DATE }}
|
||||
tags: |
|
||||
${{ env.DOCKERHUB_REPO }}:latest-arm64
|
||||
${{ env.DOCKERHUB_REPO }}:${{ env.VERSION }}-arm64
|
||||
|
||||
docker_manifest:
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
- docker_amd64
|
||||
- docker_arm64
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Generate Build Metadata
|
||||
run: |
|
||||
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||
- name: Create and push multi-arch manifests
|
||||
run: |
|
||||
docker buildx imagetools create \
|
||||
--tag "${DOCKERHUB_REPO}:latest" \
|
||||
"${DOCKERHUB_REPO}:latest-amd64" \
|
||||
"${DOCKERHUB_REPO}:latest-arm64"
|
||||
docker buildx imagetools create \
|
||||
--tag "${DOCKERHUB_REPO}:${VERSION}" \
|
||||
"${DOCKERHUB_REPO}:${VERSION}-amd64" \
|
||||
"${DOCKERHUB_REPO}:${VERSION}-arm64"
|
||||
- name: Cleanup temporary tags
|
||||
continue-on-error: true
|
||||
env:
|
||||
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
namespace="${DOCKERHUB_REPO%%/*}"
|
||||
repo_name="${DOCKERHUB_REPO#*/}"
|
||||
|
||||
token="$(
|
||||
curl -fsSL \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d "{\"username\":\"${DOCKERHUB_USERNAME}\",\"password\":\"${DOCKERHUB_TOKEN}\"}" \
|
||||
'https://hub.docker.com/v2/users/login/' \
|
||||
| python3 -c 'import json,sys; print(json.load(sys.stdin)["token"])'
|
||||
)"
|
||||
|
||||
delete_tag() {
|
||||
local tag="$1"
|
||||
local url="https://hub.docker.com/v2/repositories/${namespace}/${repo_name}/tags/${tag}/"
|
||||
local http_code
|
||||
http_code="$(curl -sS -o /dev/null -w "%{http_code}" -X DELETE -H "Authorization: JWT ${token}" "${url}" || true)"
|
||||
if [ "${http_code}" = "204" ] || [ "${http_code}" = "404" ]; then
|
||||
echo "Docker Hub tag removed (or missing): ${DOCKERHUB_REPO}:${tag} (HTTP ${http_code})"
|
||||
return 0
|
||||
fi
|
||||
echo "Docker Hub tag delete failed: ${DOCKERHUB_REPO}:${tag} (HTTP ${http_code})"
|
||||
return 0
|
||||
}
|
||||
|
||||
delete_tag "latest-amd64"
|
||||
delete_tag "latest-arm64"
|
||||
delete_tag "${VERSION}-amd64"
|
||||
delete_tag "${VERSION}-arm64"
|
||||
|
||||
23
.github/workflows/pr-test-build.yml
vendored
Normal file
23
.github/workflows/pr-test-build.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: pr-test-build
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
cache: true
|
||||
- name: Build
|
||||
run: |
|
||||
go build -o test-output ./cmd/server
|
||||
rm -f test-output
|
||||
2
.github/workflows/release.yaml
vendored
2
.github/workflows/release.yaml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- run: git fetch --force --tags
|
||||
- uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '>=1.24.0'
|
||||
go-version: '>=1.26.0'
|
||||
cache: true
|
||||
- name: Generate Build Metadata
|
||||
run: |
|
||||
|
||||
23
.gitignore
vendored
23
.gitignore
vendored
@@ -1,21 +1,28 @@
|
||||
# Binaries
|
||||
cli-proxy-api
|
||||
cliproxy
|
||||
*.exe
|
||||
|
||||
|
||||
# Configuration
|
||||
config.yaml
|
||||
.env
|
||||
|
||||
.mcp.json
|
||||
# Generated content
|
||||
bin/*
|
||||
logs/*
|
||||
conv/*
|
||||
temp/*
|
||||
refs/*
|
||||
tmp/*
|
||||
|
||||
# Storage backends
|
||||
pgstore/*
|
||||
gitstore/*
|
||||
objectstore/*
|
||||
|
||||
# Static assets
|
||||
static/*
|
||||
refs/*
|
||||
|
||||
# Authentication data
|
||||
auths/*
|
||||
@@ -29,9 +36,21 @@ GEMINI.md
|
||||
|
||||
# Tooling metadata
|
||||
.vscode/*
|
||||
.codex/*
|
||||
.claude/*
|
||||
.gemini/*
|
||||
.serena/*
|
||||
.agent/*
|
||||
.agents/*
|
||||
.agents/*
|
||||
.opencode/*
|
||||
.idea/*
|
||||
.bmad/*
|
||||
_bmad/*
|
||||
_bmad-output/*
|
||||
.mcp/cache/
|
||||
|
||||
# macOS
|
||||
.DS_Store
|
||||
._*
|
||||
*.bak
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM golang:1.24-alpine AS builder
|
||||
FROM golang:1.26-alpine AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
117
README.md
117
README.md
@@ -10,13 +10,126 @@ The Plus release stays in lockstep with the mainline features.
|
||||
|
||||
## Differences from the Mainline
|
||||
|
||||
- Added GitHub Copilot support (OAuth login), provided by [em4gp](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)
|
||||
[](https://z.ai/subscribe?ic=8JVLJQFSKB)
|
||||
|
||||
## New Features (Plus Enhanced)
|
||||
|
||||
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $10/month. It provides access to their flagship GLM-4.7 & (GLM-5 Only Available for Pro Users)model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
|
||||
|
||||
## Kiro Authentication
|
||||
|
||||
### CLI Login
|
||||
|
||||
> **Note:** Google/GitHub login is not available for third-party applications due to AWS Cognito restrictions.
|
||||
|
||||
**AWS Builder ID** (recommended):
|
||||
|
||||
```bash
|
||||
# Device code flow
|
||||
./CLIProxyAPI --kiro-aws-login
|
||||
|
||||
# Authorization code flow
|
||||
./CLIProxyAPI --kiro-aws-authcode
|
||||
```
|
||||
|
||||
**Import token from Kiro IDE:**
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI --kiro-import
|
||||
```
|
||||
|
||||
To get a token from Kiro IDE:
|
||||
|
||||
1. Open Kiro IDE and login with Google (or GitHub)
|
||||
2. Find the token file: `~/.kiro/kiro-auth-token.json`
|
||||
3. Run: `./CLIProxyAPI --kiro-import`
|
||||
|
||||
**AWS IAM Identity Center (IDC):**
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
|
||||
|
||||
# Specify region
|
||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
|
||||
```
|
||||
|
||||
**Additional flags:**
|
||||
|
||||
| Flag | Description |
|
||||
|------|-------------|
|
||||
| `--no-browser` | Don't open browser automatically, print URL instead |
|
||||
| `--no-incognito` | Use existing browser session (Kiro defaults to incognito). Useful for corporate SSO that requires an authenticated browser session |
|
||||
| `--kiro-idc-start-url` | IDC Start URL (required with `--kiro-idc-login`) |
|
||||
| `--kiro-idc-region` | IDC region (default: `us-east-1`) |
|
||||
| `--kiro-idc-flow` | IDC flow type: `authcode` (default) or `device` |
|
||||
|
||||
### Web-based OAuth Login
|
||||
|
||||
Access the Kiro OAuth web interface at:
|
||||
|
||||
```
|
||||
http://your-server:8080/v0/oauth/kiro
|
||||
```
|
||||
|
||||
This provides a browser-based OAuth flow for Kiro (AWS CodeWhisperer) authentication with:
|
||||
- AWS Builder ID login
|
||||
- AWS Identity Center (IDC) login
|
||||
- Token import from Kiro IDE
|
||||
|
||||
## Quick Deployment with Docker
|
||||
|
||||
### One-Command Deployment
|
||||
|
||||
```bash
|
||||
# Create deployment directory
|
||||
mkdir -p ~/cli-proxy && cd ~/cli-proxy
|
||||
|
||||
# Create docker-compose.yml
|
||||
cat > docker-compose.yml << 'EOF'
|
||||
services:
|
||||
cli-proxy-api:
|
||||
image: eceasy/cli-proxy-api-plus:latest
|
||||
container_name: cli-proxy-api-plus
|
||||
ports:
|
||||
- "8317:8317"
|
||||
volumes:
|
||||
- ./config.yaml:/CLIProxyAPI/config.yaml
|
||||
- ./auths:/root/.cli-proxy-api
|
||||
- ./logs:/CLIProxyAPI/logs
|
||||
restart: unless-stopped
|
||||
EOF
|
||||
|
||||
# Download example config
|
||||
curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
|
||||
|
||||
# Pull and start
|
||||
docker compose pull && docker compose up -d
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
Edit `config.yaml` before starting:
|
||||
|
||||
```yaml
|
||||
# Basic configuration example
|
||||
server:
|
||||
port: 8317
|
||||
|
||||
# Add your provider configurations here
|
||||
```
|
||||
|
||||
### Update to Latest Version
|
||||
|
||||
```bash
|
||||
cd ~/cli-proxy
|
||||
docker compose pull && docker compose up -d
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected.
|
||||
|
||||
If you need to submit any non-third-party provider changes, please open them against the mainline repository.
|
||||
If you need to submit any non-third-party provider changes, please open them against the [mainline](https://github.com/router-for-me/CLIProxyAPI) repository.
|
||||
|
||||
## License
|
||||
|
||||
|
||||
117
README_CN.md
117
README_CN.md
@@ -10,13 +10,126 @@
|
||||
|
||||
## 与主线版本版本差异
|
||||
|
||||
- 新增 GitHub Copilot 支持(OAuth 登录),由[em4gp](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供
|
||||
[](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
|
||||
|
||||
## 新增功能 (Plus 增强版)
|
||||
|
||||
GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元,即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.7(受限于算力,目前仅限Pro用户开放),为开发者提供顶尖的编码体验。
|
||||
|
||||
智谱AI为本产品提供了特别优惠,使用以下链接购买可以享受九折优惠:https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII
|
||||
|
||||
### 命令行登录
|
||||
|
||||
> **注意:** 由于 AWS Cognito 限制,Google/GitHub 登录不可用于第三方应用。
|
||||
|
||||
**AWS Builder ID**(推荐):
|
||||
|
||||
```bash
|
||||
# 设备码流程
|
||||
./CLIProxyAPI --kiro-aws-login
|
||||
|
||||
# 授权码流程
|
||||
./CLIProxyAPI --kiro-aws-authcode
|
||||
```
|
||||
|
||||
**从 Kiro IDE 导入令牌:**
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI --kiro-import
|
||||
```
|
||||
|
||||
获取令牌步骤:
|
||||
|
||||
1. 打开 Kiro IDE,使用 Google(或 GitHub)登录
|
||||
2. 找到令牌文件:`~/.kiro/kiro-auth-token.json`
|
||||
3. 运行:`./CLIProxyAPI --kiro-import`
|
||||
|
||||
**AWS IAM Identity Center (IDC):**
|
||||
|
||||
```bash
|
||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
|
||||
|
||||
# 指定区域
|
||||
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
|
||||
```
|
||||
|
||||
**附加参数:**
|
||||
|
||||
| 参数 | 说明 |
|
||||
|------|------|
|
||||
| `--no-browser` | 不自动打开浏览器,打印 URL |
|
||||
| `--no-incognito` | 使用已有浏览器会话(Kiro 默认使用无痕模式),适用于需要已登录浏览器会话的企业 SSO 场景 |
|
||||
| `--kiro-idc-start-url` | IDC Start URL(`--kiro-idc-login` 必需) |
|
||||
| `--kiro-idc-region` | IDC 区域(默认:`us-east-1`) |
|
||||
| `--kiro-idc-flow` | IDC 流程类型:`authcode`(默认)或 `device` |
|
||||
|
||||
### 网页端 OAuth 登录
|
||||
|
||||
访问 Kiro OAuth 网页认证界面:
|
||||
|
||||
```
|
||||
http://your-server:8080/v0/oauth/kiro
|
||||
```
|
||||
|
||||
提供基于浏览器的 Kiro (AWS CodeWhisperer) OAuth 认证流程,支持:
|
||||
- AWS Builder ID 登录
|
||||
- AWS Identity Center (IDC) 登录
|
||||
- 从 Kiro IDE 导入令牌
|
||||
|
||||
## Docker 快速部署
|
||||
|
||||
### 一键部署
|
||||
|
||||
```bash
|
||||
# 创建部署目录
|
||||
mkdir -p ~/cli-proxy && cd ~/cli-proxy
|
||||
|
||||
# 创建 docker-compose.yml
|
||||
cat > docker-compose.yml << 'EOF'
|
||||
services:
|
||||
cli-proxy-api:
|
||||
image: eceasy/cli-proxy-api-plus:latest
|
||||
container_name: cli-proxy-api-plus
|
||||
ports:
|
||||
- "8317:8317"
|
||||
volumes:
|
||||
- ./config.yaml:/CLIProxyAPI/config.yaml
|
||||
- ./auths:/root/.cli-proxy-api
|
||||
- ./logs:/CLIProxyAPI/logs
|
||||
restart: unless-stopped
|
||||
EOF
|
||||
|
||||
# 下载示例配置
|
||||
curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
|
||||
|
||||
# 拉取并启动
|
||||
docker compose pull && docker compose up -d
|
||||
```
|
||||
|
||||
### 配置说明
|
||||
|
||||
启动前请编辑 `config.yaml`:
|
||||
|
||||
```yaml
|
||||
# 基本配置示例
|
||||
server:
|
||||
port: 8317
|
||||
|
||||
# 在此添加你的供应商配置
|
||||
```
|
||||
|
||||
### 更新到最新版本
|
||||
|
||||
```bash
|
||||
cd ~/cli-proxy
|
||||
docker compose pull && docker compose up -d
|
||||
```
|
||||
|
||||
## 贡献
|
||||
|
||||
该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。
|
||||
|
||||
如果需要提交任何非第三方供应商支持的 Pull Request,请提交到主线版本。
|
||||
如果需要提交任何非第三方供应商支持的 Pull Request,请提交到[主线](https://github.com/router-for-me/CLIProxyAPI)版本。
|
||||
|
||||
## 许可证
|
||||
|
||||
|
||||
BIN
assets/aicodemirror.png
Normal file
BIN
assets/aicodemirror.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 45 KiB |
BIN
assets/packycode.png
Normal file
BIN
assets/packycode.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 8.1 KiB |
@@ -8,6 +8,7 @@ import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -17,6 +18,7 @@ import (
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cmd"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
@@ -25,6 +27,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
@@ -47,6 +50,19 @@ func init() {
|
||||
buildinfo.BuildDate = BuildDate
|
||||
}
|
||||
|
||||
// setKiroIncognitoMode sets the incognito browser mode for Kiro authentication.
|
||||
// Kiro defaults to incognito mode for multi-account support.
|
||||
// Users can explicitly override with --incognito or --no-incognito flags.
|
||||
func setKiroIncognitoMode(cfg *config.Config, useIncognito, noIncognito bool) {
|
||||
if useIncognito {
|
||||
cfg.IncognitoBrowser = true
|
||||
} else if noIncognito {
|
||||
cfg.IncognitoBrowser = false
|
||||
} else {
|
||||
cfg.IncognitoBrowser = true // Kiro default
|
||||
}
|
||||
}
|
||||
|
||||
// main is the entry point of the application.
|
||||
// It parses command-line flags, loads configuration, and starts the appropriate
|
||||
// service based on the provided flags (login, codex-login, or server mode).
|
||||
@@ -56,32 +72,66 @@ func main() {
|
||||
// Command-line flags to control the application's behavior.
|
||||
var login bool
|
||||
var codexLogin bool
|
||||
var codexDeviceLogin bool
|
||||
var claudeLogin bool
|
||||
var qwenLogin bool
|
||||
var kiloLogin bool
|
||||
var iflowLogin bool
|
||||
var iflowCookie bool
|
||||
var noBrowser bool
|
||||
var oauthCallbackPort int
|
||||
var antigravityLogin bool
|
||||
var kimiLogin bool
|
||||
var kiroLogin bool
|
||||
var kiroGoogleLogin bool
|
||||
var kiroAWSLogin bool
|
||||
var kiroAWSAuthCode bool
|
||||
var kiroImport bool
|
||||
var kiroIDCLogin bool
|
||||
var kiroIDCStartURL string
|
||||
var kiroIDCRegion string
|
||||
var kiroIDCFlow string
|
||||
var githubCopilotLogin bool
|
||||
var projectID string
|
||||
var vertexImport string
|
||||
var configPath string
|
||||
var password string
|
||||
var tuiMode bool
|
||||
var standalone bool
|
||||
var noIncognito bool
|
||||
var useIncognito bool
|
||||
|
||||
// Define command-line flags for different operation modes.
|
||||
flag.BoolVar(&login, "login", false, "Login Google Account")
|
||||
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
||||
flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow")
|
||||
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
|
||||
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
|
||||
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
|
||||
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
||||
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
||||
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
|
||||
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
|
||||
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
|
||||
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
||||
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
|
||||
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
|
||||
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
|
||||
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
|
||||
flag.BoolVar(&kiroAWSAuthCode, "kiro-aws-authcode", false, "Login to Kiro using AWS Builder ID (authorization code flow, better UX)")
|
||||
flag.BoolVar(&kiroImport, "kiro-import", false, "Import Kiro token from Kiro IDE (~/.aws/sso/cache/kiro-auth-token.json)")
|
||||
flag.BoolVar(&kiroIDCLogin, "kiro-idc-login", false, "Login to Kiro using IAM Identity Center (IDC)")
|
||||
flag.StringVar(&kiroIDCStartURL, "kiro-idc-start-url", "", "IDC start URL (required with --kiro-idc-login)")
|
||||
flag.StringVar(&kiroIDCRegion, "kiro-idc-region", "", "IDC region (default: us-east-1)")
|
||||
flag.StringVar(&kiroIDCFlow, "kiro-idc-flow", "", "IDC flow type: authcode (default) or device")
|
||||
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
|
||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
||||
flag.StringVar(&password, "password", "", "")
|
||||
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
|
||||
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
|
||||
|
||||
flag.CommandLine.Usage = func() {
|
||||
out := flag.CommandLine.Output()
|
||||
@@ -141,7 +191,8 @@ func main() {
|
||||
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to get working directory: %v", err)
|
||||
log.Errorf("failed to get working directory: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Load environment variables from .env if present.
|
||||
@@ -235,13 +286,15 @@ func main() {
|
||||
})
|
||||
cancel()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize postgres token store: %v", err)
|
||||
log.Errorf("failed to initialize postgres token store: %v", err)
|
||||
return
|
||||
}
|
||||
examplePath := filepath.Join(wd, "config.example.yaml")
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second)
|
||||
if errBootstrap := pgStoreInst.Bootstrap(ctx, examplePath); errBootstrap != nil {
|
||||
cancel()
|
||||
log.Fatalf("failed to bootstrap postgres-backed config: %v", errBootstrap)
|
||||
log.Errorf("failed to bootstrap postgres-backed config: %v", errBootstrap)
|
||||
return
|
||||
}
|
||||
cancel()
|
||||
configFilePath = pgStoreInst.ConfigPath()
|
||||
@@ -264,7 +317,8 @@ func main() {
|
||||
if strings.Contains(resolvedEndpoint, "://") {
|
||||
parsed, errParse := url.Parse(resolvedEndpoint)
|
||||
if errParse != nil {
|
||||
log.Fatalf("failed to parse object store endpoint %q: %v", objectStoreEndpoint, errParse)
|
||||
log.Errorf("failed to parse object store endpoint %q: %v", objectStoreEndpoint, errParse)
|
||||
return
|
||||
}
|
||||
switch strings.ToLower(parsed.Scheme) {
|
||||
case "http":
|
||||
@@ -272,10 +326,12 @@ func main() {
|
||||
case "https":
|
||||
useSSL = true
|
||||
default:
|
||||
log.Fatalf("unsupported object store scheme %q (only http and https are allowed)", parsed.Scheme)
|
||||
log.Errorf("unsupported object store scheme %q (only http and https are allowed)", parsed.Scheme)
|
||||
return
|
||||
}
|
||||
if parsed.Host == "" {
|
||||
log.Fatalf("object store endpoint %q is missing host information", objectStoreEndpoint)
|
||||
log.Errorf("object store endpoint %q is missing host information", objectStoreEndpoint)
|
||||
return
|
||||
}
|
||||
resolvedEndpoint = parsed.Host
|
||||
if parsed.Path != "" && parsed.Path != "/" {
|
||||
@@ -294,13 +350,15 @@ func main() {
|
||||
}
|
||||
objectStoreInst, err = store.NewObjectTokenStore(objCfg)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize object token store: %v", err)
|
||||
log.Errorf("failed to initialize object token store: %v", err)
|
||||
return
|
||||
}
|
||||
examplePath := filepath.Join(wd, "config.example.yaml")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
if errBootstrap := objectStoreInst.Bootstrap(ctx, examplePath); errBootstrap != nil {
|
||||
cancel()
|
||||
log.Fatalf("failed to bootstrap object-backed config: %v", errBootstrap)
|
||||
log.Errorf("failed to bootstrap object-backed config: %v", errBootstrap)
|
||||
return
|
||||
}
|
||||
cancel()
|
||||
configFilePath = objectStoreInst.ConfigPath()
|
||||
@@ -325,7 +383,8 @@ func main() {
|
||||
gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword)
|
||||
gitStoreInst.SetBaseDir(authDir)
|
||||
if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil {
|
||||
log.Fatalf("failed to prepare git token store: %v", errRepo)
|
||||
log.Errorf("failed to prepare git token store: %v", errRepo)
|
||||
return
|
||||
}
|
||||
configFilePath = gitStoreInst.ConfigPath()
|
||||
if configFilePath == "" {
|
||||
@@ -334,17 +393,21 @@ func main() {
|
||||
if _, statErr := os.Stat(configFilePath); errors.Is(statErr, fs.ErrNotExist) {
|
||||
examplePath := filepath.Join(wd, "config.example.yaml")
|
||||
if _, errExample := os.Stat(examplePath); errExample != nil {
|
||||
log.Fatalf("failed to find template config file: %v", errExample)
|
||||
log.Errorf("failed to find template config file: %v", errExample)
|
||||
return
|
||||
}
|
||||
if errCopy := misc.CopyConfigTemplate(examplePath, configFilePath); errCopy != nil {
|
||||
log.Fatalf("failed to bootstrap git-backed config: %v", errCopy)
|
||||
log.Errorf("failed to bootstrap git-backed config: %v", errCopy)
|
||||
return
|
||||
}
|
||||
if errCommit := gitStoreInst.PersistConfig(context.Background()); errCommit != nil {
|
||||
log.Fatalf("failed to commit initial git-backed config: %v", errCommit)
|
||||
log.Errorf("failed to commit initial git-backed config: %v", errCommit)
|
||||
return
|
||||
}
|
||||
log.Infof("git-backed config initialized from template: %s", configFilePath)
|
||||
} else if statErr != nil {
|
||||
log.Fatalf("failed to inspect git-backed config: %v", statErr)
|
||||
log.Errorf("failed to inspect git-backed config: %v", statErr)
|
||||
return
|
||||
}
|
||||
cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
|
||||
if err == nil {
|
||||
@@ -357,13 +420,15 @@ func main() {
|
||||
} else {
|
||||
wd, err = os.Getwd()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to get working directory: %v", err)
|
||||
log.Errorf("failed to get working directory: %v", err)
|
||||
return
|
||||
}
|
||||
configFilePath = filepath.Join(wd, "config.yaml")
|
||||
cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatalf("failed to load config: %v", err)
|
||||
log.Errorf("failed to load config: %v", err)
|
||||
return
|
||||
}
|
||||
if cfg == nil {
|
||||
cfg = &config.Config{}
|
||||
@@ -392,8 +457,9 @@ func main() {
|
||||
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
|
||||
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||
|
||||
if err = logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
|
||||
log.Fatalf("failed to configure log output: %v", err)
|
||||
if err = logging.ConfigureLogOutput(cfg); err != nil {
|
||||
log.Errorf("failed to configure log output: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("CLIProxyAPI Version: %s, Commit: %s, BuiltAt: %s", buildinfo.Version, buildinfo.Commit, buildinfo.BuildDate)
|
||||
@@ -402,7 +468,8 @@ func main() {
|
||||
util.SetLogLevel(cfg)
|
||||
|
||||
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
|
||||
log.Fatalf("failed to resolve auth directory: %v", errResolveAuthDir)
|
||||
log.Errorf("failed to resolve auth directory: %v", errResolveAuthDir)
|
||||
return
|
||||
} else {
|
||||
cfg.AuthDir = resolvedAuthDir
|
||||
}
|
||||
@@ -410,7 +477,8 @@ func main() {
|
||||
|
||||
// Create login options to be used in authentication flows.
|
||||
options := &cmd.LoginOptions{
|
||||
NoBrowser: noBrowser,
|
||||
NoBrowser: noBrowser,
|
||||
CallbackPort: oauthCallbackPort,
|
||||
}
|
||||
|
||||
// Register the shared token store once so all components use the same persistence backend.
|
||||
@@ -425,7 +493,7 @@ func main() {
|
||||
}
|
||||
|
||||
// Register built-in access providers before constructing services.
|
||||
configaccess.Register()
|
||||
configaccess.Register(&cfg.SDKConfig)
|
||||
|
||||
// Handle different command modes based on the provided flags.
|
||||
|
||||
@@ -444,15 +512,56 @@ func main() {
|
||||
} else if codexLogin {
|
||||
// Handle Codex login
|
||||
cmd.DoCodexLogin(cfg, options)
|
||||
} else if codexDeviceLogin {
|
||||
// Handle Codex device-code login
|
||||
cmd.DoCodexDeviceLogin(cfg, options)
|
||||
} else if claudeLogin {
|
||||
// Handle Claude login
|
||||
cmd.DoClaudeLogin(cfg, options)
|
||||
} else if qwenLogin {
|
||||
cmd.DoQwenLogin(cfg, options)
|
||||
} else if kiloLogin {
|
||||
cmd.DoKiloLogin(cfg, options)
|
||||
} else if iflowLogin {
|
||||
cmd.DoIFlowLogin(cfg, options)
|
||||
} else if iflowCookie {
|
||||
cmd.DoIFlowCookieAuth(cfg, options)
|
||||
} else if kimiLogin {
|
||||
cmd.DoKimiLogin(cfg, options)
|
||||
} else if kiroLogin {
|
||||
// For Kiro auth, default to incognito mode for multi-account support
|
||||
// Users can explicitly override with --no-incognito
|
||||
// Note: This config mutation is safe - auth commands exit after completion
|
||||
// and don't share config with StartService (which is in the else branch)
|
||||
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||
kiro.InitFingerprintConfig(cfg)
|
||||
cmd.DoKiroLogin(cfg, options)
|
||||
} else if kiroGoogleLogin {
|
||||
// For Kiro auth, default to incognito mode for multi-account support
|
||||
// Users can explicitly override with --no-incognito
|
||||
// Note: This config mutation is safe - auth commands exit after completion
|
||||
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||
kiro.InitFingerprintConfig(cfg)
|
||||
cmd.DoKiroGoogleLogin(cfg, options)
|
||||
} else if kiroAWSLogin {
|
||||
// For Kiro auth, default to incognito mode for multi-account support
|
||||
// Users can explicitly override with --no-incognito
|
||||
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||
kiro.InitFingerprintConfig(cfg)
|
||||
cmd.DoKiroAWSLogin(cfg, options)
|
||||
} else if kiroAWSAuthCode {
|
||||
// For Kiro auth with authorization code flow (better UX)
|
||||
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||
kiro.InitFingerprintConfig(cfg)
|
||||
cmd.DoKiroAWSAuthCodeLogin(cfg, options)
|
||||
} else if kiroImport {
|
||||
kiro.InitFingerprintConfig(cfg)
|
||||
cmd.DoKiroImport(cfg, options)
|
||||
} else if kiroIDCLogin {
|
||||
// For Kiro IDC auth, default to incognito mode for multi-account support
|
||||
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||
kiro.InitFingerprintConfig(cfg)
|
||||
cmd.DoKiroIDCLogin(cfg, options, kiroIDCStartURL, kiroIDCRegion, kiroIDCFlow)
|
||||
} else {
|
||||
// In cloud deploy mode without config file, just wait for shutdown signals
|
||||
if isCloudDeploy && !configFileExists {
|
||||
@@ -460,8 +569,89 @@ func main() {
|
||||
cmd.WaitForCloudDeploy()
|
||||
return
|
||||
}
|
||||
// Start the main proxy service
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
cmd.StartService(cfg, configFilePath, password)
|
||||
if tuiMode {
|
||||
if standalone {
|
||||
// Standalone mode: start an embedded local server and connect TUI client to it.
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
hook := tui.NewLogHook(2000)
|
||||
hook.SetFormatter(&logging.LogFormatter{})
|
||||
log.AddHook(hook)
|
||||
|
||||
origStdout := os.Stdout
|
||||
origStderr := os.Stderr
|
||||
origLogOutput := log.StandardLogger().Out
|
||||
log.SetOutput(io.Discard)
|
||||
|
||||
devNull, errOpenDevNull := os.Open(os.DevNull)
|
||||
if errOpenDevNull == nil {
|
||||
os.Stdout = devNull
|
||||
os.Stderr = devNull
|
||||
}
|
||||
|
||||
restoreIO := func() {
|
||||
os.Stdout = origStdout
|
||||
os.Stderr = origStderr
|
||||
log.SetOutput(origLogOutput)
|
||||
if devNull != nil {
|
||||
_ = devNull.Close()
|
||||
}
|
||||
}
|
||||
|
||||
localMgmtPassword := fmt.Sprintf("tui-%d-%d", os.Getpid(), time.Now().UnixNano())
|
||||
if password == "" {
|
||||
password = localMgmtPassword
|
||||
}
|
||||
|
||||
cancel, done := cmd.StartServiceBackground(cfg, configFilePath, password)
|
||||
|
||||
client := tui.NewClient(cfg.Port, password)
|
||||
ready := false
|
||||
backoff := 100 * time.Millisecond
|
||||
for i := 0; i < 30; i++ {
|
||||
if _, errGetConfig := client.GetConfig(); errGetConfig == nil {
|
||||
ready = true
|
||||
break
|
||||
}
|
||||
time.Sleep(backoff)
|
||||
if backoff < time.Second {
|
||||
backoff = time.Duration(float64(backoff) * 1.5)
|
||||
}
|
||||
}
|
||||
|
||||
if !ready {
|
||||
restoreIO()
|
||||
cancel()
|
||||
<-done
|
||||
fmt.Fprintf(os.Stderr, "TUI error: embedded server is not ready\n")
|
||||
return
|
||||
}
|
||||
|
||||
if errRun := tui.Run(cfg.Port, password, hook, origStdout); errRun != nil {
|
||||
restoreIO()
|
||||
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
|
||||
} else {
|
||||
restoreIO()
|
||||
}
|
||||
|
||||
cancel()
|
||||
<-done
|
||||
} else {
|
||||
// Default TUI mode: pure management client.
|
||||
// The proxy server must already be running.
|
||||
if errRun := tui.Run(cfg.Port, password, nil, os.Stdout); errRun != nil {
|
||||
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Start the main proxy service
|
||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||
|
||||
if cfg.AuthDir != "" {
|
||||
kiro.InitializeAndStart(cfg.AuthDir, cfg)
|
||||
defer kiro.StopGlobalRefreshManager()
|
||||
}
|
||||
|
||||
cmd.StartService(cfg, configFilePath, password)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
# Server host/interface to bind to. Default is empty ("") to bind all interfaces (IPv4 + IPv6).
|
||||
# Use "127.0.0.1" or "localhost" to restrict access to local machine only.
|
||||
host: ''
|
||||
|
||||
# Server port
|
||||
port: 8317
|
||||
|
||||
# TLS settings for HTTPS. When enabled, the server listens with the provided certificate and key.
|
||||
tls:
|
||||
enable: false
|
||||
cert: ""
|
||||
key: ""
|
||||
cert: ''
|
||||
key: ''
|
||||
|
||||
# Management API settings
|
||||
remote-management:
|
||||
@@ -16,34 +20,70 @@ remote-management:
|
||||
# Management key. If a plaintext value is provided here, it will be hashed on startup.
|
||||
# All management requests (even from localhost) require this key.
|
||||
# Leave empty to disable the Management API entirely (404 for all /v0/management routes).
|
||||
secret-key: ""
|
||||
secret-key: ''
|
||||
|
||||
# Disable the bundled management control panel asset download and HTTP route when true.
|
||||
disable-control-panel: false
|
||||
|
||||
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
|
||||
panel-github-repository: 'https://github.com/router-for-me/Cli-Proxy-API-Management-Center'
|
||||
|
||||
# Authentication directory (supports ~ for home directory)
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
auth-dir: '~/.cli-proxy-api'
|
||||
|
||||
# API keys for authentication
|
||||
api-keys:
|
||||
- "your-api-key-1"
|
||||
- "your-api-key-2"
|
||||
- 'your-api-key-1'
|
||||
- 'your-api-key-2'
|
||||
- 'your-api-key-3'
|
||||
|
||||
# Enable debug logging
|
||||
debug: false
|
||||
|
||||
# Enable pprof HTTP debug server (host:port). Keep it bound to localhost for safety.
|
||||
pprof:
|
||||
enable: false
|
||||
addr: '127.0.0.1:8316'
|
||||
|
||||
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
|
||||
commercial-mode: false
|
||||
|
||||
# Open OAuth URLs in incognito/private browser mode.
|
||||
# Useful when you want to login with a different account without logging out from your current session.
|
||||
# Default: false (but Kiro auth defaults to true for multi-account support)
|
||||
incognito-browser: true
|
||||
|
||||
# When true, write application logs to rotating files instead of stdout
|
||||
logging-to-file: false
|
||||
|
||||
# Maximum total size (MB) of log files under the logs directory. When exceeded, the oldest log
|
||||
# files are deleted until within the limit. Set to 0 to disable.
|
||||
logs-max-total-size-mb: 0
|
||||
|
||||
# Maximum number of error log files retained when request logging is disabled.
|
||||
# When exceeded, the oldest error log files are deleted. Default is 10. Set to 0 to disable cleanup.
|
||||
error-logs-max-files: 10
|
||||
|
||||
# When false, disable in-memory usage statistics aggregation
|
||||
usage-statistics-enabled: false
|
||||
|
||||
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
|
||||
proxy-url: ""
|
||||
proxy-url: ''
|
||||
|
||||
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
||||
force-model-prefix: false
|
||||
|
||||
# When true, forward filtered upstream response headers to downstream clients.
|
||||
# Default is false (disabled).
|
||||
passthrough-headers: false
|
||||
|
||||
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
||||
request-retry: 3
|
||||
|
||||
# Maximum number of different credentials to try for one failed request.
|
||||
# Set to 0 to keep legacy behavior (try all available credentials).
|
||||
max-retry-credentials: 0
|
||||
|
||||
# Maximum wait time in seconds for a cooled-down credential before triggering a retry.
|
||||
max-retry-interval: 30
|
||||
|
||||
@@ -52,16 +92,32 @@ quota-exceeded:
|
||||
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
||||
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
|
||||
|
||||
# Routing strategy for selecting credentials when multiple match.
|
||||
routing:
|
||||
strategy: 'round-robin' # round-robin (default), fill-first
|
||||
|
||||
# When true, enable authentication for the WebSocket API (/v1/ws).
|
||||
ws-auth: false
|
||||
|
||||
# When > 0, emit blank lines every N seconds for non-streaming responses to prevent idle timeouts.
|
||||
nonstream-keepalive-interval: 0
|
||||
|
||||
# Streaming behavior (SSE keep-alives + safe bootstrap retries).
|
||||
# streaming:
|
||||
# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.
|
||||
# bootstrap-retries: 1 # Default: 0 (disabled). Retries before first byte is sent.
|
||||
|
||||
# Gemini API keys
|
||||
# gemini-api-key:
|
||||
# - api-key: "AIzaSy...01"
|
||||
# prefix: "test" # optional: require calls like "test/gemini-3-pro-preview" to target this credential
|
||||
# base-url: "https://generativelanguage.googleapis.com"
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080"
|
||||
# models:
|
||||
# - name: "gemini-2.5-flash" # upstream model name
|
||||
# alias: "gemini-flash" # client alias mapped to the upstream model
|
||||
# excluded-models:
|
||||
# - "gemini-2.5-pro" # exclude specific models from this provider (exact match)
|
||||
# - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro)
|
||||
@@ -72,10 +128,14 @@ ws-auth: false
|
||||
# Codex API keys
|
||||
# codex-api-key:
|
||||
# - api-key: "sk-atSM..."
|
||||
# prefix: "test" # optional: require calls like "test/gpt-5-codex" to target this credential
|
||||
# base-url: "https://www.example.com" # use the custom codex API endpoint
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||
# models:
|
||||
# - name: "gpt-5-codex" # upstream model name
|
||||
# alias: "codex-latest" # client alias mapped to the upstream model
|
||||
# excluded-models:
|
||||
# - "gpt-5.1" # exclude specific models (exact match)
|
||||
# - "gpt-5-*" # wildcard matching prefix (e.g. gpt-5-medium, gpt-5-codex)
|
||||
@@ -86,22 +146,69 @@ ws-auth: false
|
||||
# claude-api-key:
|
||||
# - api-key: "sk-atSM..." # use the official claude API key, no need to set the base url
|
||||
# - api-key: "sk-atSM..."
|
||||
# prefix: "test" # optional: require calls like "test/claude-sonnet-latest" to target this credential
|
||||
# base-url: "https://www.example.com" # use the custom claude API endpoint
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||
# models:
|
||||
# - name: "claude-3-5-sonnet-20241022" # upstream model name
|
||||
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
||||
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
||||
# excluded-models:
|
||||
# - "claude-opus-4-5-20251101" # exclude specific models (exact match)
|
||||
# - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219)
|
||||
# - "*-think" # wildcard matching suffix (e.g. claude-opus-4-5-thinking)
|
||||
# - "*-thinking" # wildcard matching suffix (e.g. claude-opus-4-5-thinking)
|
||||
# - "*haiku*" # wildcard matching substring (e.g. claude-3-5-haiku-20241022)
|
||||
# cloak: # optional: request cloaking for non-Claude-Code clients
|
||||
# mode: "auto" # "auto" (default): cloak only when client is not Claude Code
|
||||
# # "always": always apply cloaking
|
||||
# # "never": never apply cloaking
|
||||
# strict-mode: false # false (default): prepend Claude Code prompt to user system messages
|
||||
# # true: strip all user system messages, keep only Claude Code prompt
|
||||
# sensitive-words: # optional: words to obfuscate with zero-width characters
|
||||
# - "API"
|
||||
# - "proxy"
|
||||
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
|
||||
|
||||
# Default headers for Claude API requests. Update when Claude Code releases new versions.
|
||||
# These are used as fallbacks when the client does not send its own headers.
|
||||
# claude-header-defaults:
|
||||
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
|
||||
# package-version: "0.74.0"
|
||||
# runtime-version: "v24.3.0"
|
||||
# timeout: "600"
|
||||
|
||||
# Kiro (AWS CodeWhisperer) configuration
|
||||
# Note: Kiro API currently only operates in us-east-1 region
|
||||
#kiro:
|
||||
# - token-file: "~/.aws/sso/cache/kiro-auth-token.json" # path to Kiro token file
|
||||
# agent-task-type: "" # optional: "vibe" or empty (API default)
|
||||
# start-url: "https://your-company.awsapps.com/start" # optional: IDC start URL (preset for login)
|
||||
# region: "us-east-1" # optional: OIDC region for IDC login and token refresh
|
||||
# - access-token: "aoaAAAAA..." # or provide tokens directly
|
||||
# refresh-token: "aorAAAAA..."
|
||||
# profile-arn: "arn:aws:codewhisperer:us-east-1:..."
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: proxy override
|
||||
|
||||
# Kilocode (OAuth-based code assistant)
|
||||
# Note: Kilocode uses OAuth device flow authentication.
|
||||
# Use the CLI command: ./server --kilo-login
|
||||
# This will save credentials to the auth directory (default: ~/.cli-proxy-api/)
|
||||
# oauth-model-alias:
|
||||
# kilo:
|
||||
# - name: "minimax/minimax-m2.5:free"
|
||||
# alias: "minimax-m2.5"
|
||||
# - name: "z-ai/glm-5:free"
|
||||
# alias: "glm-5"
|
||||
# oauth-excluded-models:
|
||||
# kilo:
|
||||
# - "kilo-claude-opus-4-6" # exclude specific models (exact match)
|
||||
# - "*:free" # wildcard matching suffix (e.g. all free models)
|
||||
|
||||
# OpenAI compatibility providers
|
||||
# openai-compatibility:
|
||||
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
|
||||
# prefix: "test" # optional: require calls like "test/kimi-k2" to target this provider's credentials
|
||||
# base-url: "https://openrouter.ai/api/v1" # The base URL of the provider.
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
@@ -112,19 +219,34 @@ ws-auth: false
|
||||
# models: # The models supported by the provider.
|
||||
# - name: "moonshotai/kimi-k2:free" # The actual model name.
|
||||
# alias: "kimi-k2" # The alias used in the API.
|
||||
# # You may repeat the same alias to build an internal model pool.
|
||||
# # The client still sees only one alias in the model list.
|
||||
# # Requests to that alias will round-robin across the upstream names below,
|
||||
# # and if the chosen upstream fails before producing output, the request will
|
||||
# # continue with the next upstream model in the same alias pool.
|
||||
# - name: "qwen3.5-plus"
|
||||
# alias: "claude-opus-4.66"
|
||||
# - name: "glm-5"
|
||||
# alias: "claude-opus-4.66"
|
||||
# - name: "kimi-k2.5"
|
||||
# alias: "claude-opus-4.66"
|
||||
|
||||
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
|
||||
# vertex-api-key:
|
||||
# - api-key: "vk-123..." # x-goog-api-key header
|
||||
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
|
||||
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# models: # optional: map aliases to upstream model names
|
||||
# - name: "gemini-2.0-flash" # upstream model name
|
||||
# - name: "gemini-2.5-flash" # upstream model name
|
||||
# alias: "vertex-flash" # client-visible alias
|
||||
# - name: "gemini-1.5-pro"
|
||||
# - name: "gemini-2.5-pro"
|
||||
# alias: "vertex-pro"
|
||||
# excluded-models: # optional: models to exclude from listing
|
||||
# - "imagen-3.0-generate-002"
|
||||
# - "imagen-*"
|
||||
|
||||
# Amp Integration
|
||||
# ampcode:
|
||||
@@ -132,21 +254,89 @@ ws-auth: false
|
||||
# upstream-url: "https://ampcode.com"
|
||||
# # Optional: Override API key for Amp upstream (otherwise uses env or file)
|
||||
# upstream-api-key: ""
|
||||
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (recommended)
|
||||
# restrict-management-to-localhost: true
|
||||
# # Per-client upstream API key mapping
|
||||
# # Maps client API keys (from top-level api-keys) to different Amp upstream API keys.
|
||||
# # Useful when different clients need to use different Amp accounts/quotas.
|
||||
# # If a client key isn't mapped, falls back to upstream-api-key (default behavior).
|
||||
# upstream-api-keys:
|
||||
# - upstream-api-key: "amp_key_for_team_a" # Upstream key to use for these clients
|
||||
# api-keys: # Client keys that use this upstream key
|
||||
# - "your-api-key-1"
|
||||
# - "your-api-key-2"
|
||||
# - upstream-api-key: "amp_key_for_team_b"
|
||||
# api-keys:
|
||||
# - "your-api-key-3"
|
||||
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false)
|
||||
# restrict-management-to-localhost: false
|
||||
# # Force model mappings to run before checking local API keys (default: false)
|
||||
# force-model-mappings: false
|
||||
# # Amp Model Mappings
|
||||
# # Route unavailable Amp models to alternative models available in your local proxy.
|
||||
# # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)
|
||||
# # but you have a similar model available (e.g., Claude Sonnet 4).
|
||||
# model-mappings:
|
||||
# - from: "claude-opus-4.5" # Model requested by Amp CLI
|
||||
# to: "claude-sonnet-4" # Route to this available model instead
|
||||
# - from: "gpt-5"
|
||||
# to: "gemini-2.5-pro"
|
||||
# - from: "claude-3-opus-20240229"
|
||||
# to: "claude-3-5-sonnet-20241022"
|
||||
# - from: "claude-opus-4-5-20251101" # Model requested by Amp CLI
|
||||
# to: "gemini-claude-opus-4-5-thinking" # Route to this available model instead
|
||||
# - from: "claude-sonnet-4-5-20250929"
|
||||
# to: "gemini-claude-sonnet-4-5-thinking"
|
||||
# - from: "claude-haiku-4-5-20251001"
|
||||
# to: "gemini-2.5-flash"
|
||||
|
||||
# Global OAuth model name aliases (per channel)
|
||||
# These aliases rename model IDs for both model listing and request routing.
|
||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
|
||||
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
||||
# You can repeat the same name with different aliases to expose multiple client model names.
|
||||
# oauth-model-alias:
|
||||
# antigravity:
|
||||
# - name: "rev19-uic3-1p"
|
||||
# alias: "gemini-2.5-computer-use-preview-10-2025"
|
||||
# - name: "gemini-3-pro-image"
|
||||
# alias: "gemini-3-pro-image-preview"
|
||||
# - name: "gemini-3-pro-high"
|
||||
# alias: "gemini-3-pro-preview"
|
||||
# - name: "gemini-3-flash"
|
||||
# alias: "gemini-3-flash-preview"
|
||||
# - name: "claude-sonnet-4-5"
|
||||
# alias: "gemini-claude-sonnet-4-5"
|
||||
# - name: "claude-sonnet-4-5-thinking"
|
||||
# alias: "gemini-claude-sonnet-4-5-thinking"
|
||||
# - name: "claude-opus-4-5-thinking"
|
||||
# alias: "gemini-claude-opus-4-5-thinking"
|
||||
# gemini-cli:
|
||||
# - name: "gemini-2.5-pro" # original model name under this channel
|
||||
# alias: "g2.5p" # client-visible alias
|
||||
# fork: true # when true, keep original and also add the alias as an extra model (default: false)
|
||||
# vertex:
|
||||
# - name: "gemini-2.5-pro"
|
||||
# alias: "g2.5p"
|
||||
# aistudio:
|
||||
# - name: "gemini-2.5-pro"
|
||||
# alias: "g2.5p"
|
||||
# claude:
|
||||
# - name: "claude-sonnet-4-5-20250929"
|
||||
# alias: "cs4.5"
|
||||
# codex:
|
||||
# - name: "gpt-5"
|
||||
# alias: "g5"
|
||||
# qwen:
|
||||
# - name: "qwen3-coder-plus"
|
||||
# alias: "qwen-plus"
|
||||
# iflow:
|
||||
# - name: "glm-4.7"
|
||||
# alias: "glm-god"
|
||||
# kimi:
|
||||
# - name: "kimi-k2.5"
|
||||
# alias: "k2.5"
|
||||
# kiro:
|
||||
# - name: "kiro-claude-opus-4-5"
|
||||
# alias: "op45"
|
||||
# github-copilot:
|
||||
# - name: "gpt-5"
|
||||
# alias: "copilot-gpt5"
|
||||
|
||||
# OAuth provider excluded models
|
||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
|
||||
# oauth-excluded-models:
|
||||
# gemini-cli:
|
||||
# - "gemini-2.5-pro" # exclude specific models (exact match)
|
||||
@@ -167,18 +357,43 @@ ws-auth: false
|
||||
# - "vision-model"
|
||||
# iflow:
|
||||
# - "tstars2.0"
|
||||
# kimi:
|
||||
# - "kimi-k2-thinking"
|
||||
# kiro:
|
||||
# - "kiro-claude-haiku-4-5"
|
||||
# github-copilot:
|
||||
# - "raptor-mini"
|
||||
|
||||
# Optional payload configuration
|
||||
# payload:
|
||||
# default: # Default rules only set parameters when they are missing in the payload.
|
||||
# - models:
|
||||
# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
|
||||
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
|
||||
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
|
||||
# params: # JSON path (gjson/sjson syntax) -> value
|
||||
# "generationConfig.thinkingConfig.thinkingBudget": 32768
|
||||
# default-raw: # Default raw rules set parameters using raw JSON when missing (must be valid JSON).
|
||||
# - models:
|
||||
# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
|
||||
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
|
||||
# params: # JSON path (gjson/sjson syntax) -> raw JSON value (strings are used as-is, must be valid JSON)
|
||||
# "generationConfig.responseJsonSchema": "{\"type\":\"object\",\"properties\":{\"answer\":{\"type\":\"string\"}}}"
|
||||
# override: # Override rules always set parameters, overwriting any existing values.
|
||||
# - models:
|
||||
# - name: "gpt-*" # Supports wildcards (e.g., "gpt-*")
|
||||
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
|
||||
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
|
||||
# params: # JSON path (gjson/sjson syntax) -> value
|
||||
# "reasoning.effort": "high"
|
||||
# override-raw: # Override raw rules always set parameters using raw JSON (must be valid JSON).
|
||||
# - models:
|
||||
# - name: "gpt-*" # Supports wildcards (e.g., "gpt-*")
|
||||
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
|
||||
# params: # JSON path (gjson/sjson syntax) -> raw JSON value (strings are used as-is, must be valid JSON)
|
||||
# "response_format": "{\"type\":\"json_schema\",\"json_schema\":{\"name\":\"answer\",\"schema\":{\"type\":\"object\"}}}"
|
||||
# filter: # Filter rules remove specified parameters from the payload.
|
||||
# - models:
|
||||
# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
|
||||
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
|
||||
# params: # JSON paths (gjson/sjson syntax) to remove from the payload
|
||||
# - "generationConfig.thinkingConfig.thinkingBudget"
|
||||
# - "generationConfig.responseJsonSchema"
|
||||
|
||||
128
docker-build.sh
128
docker-build.sh
@@ -5,9 +5,115 @@
|
||||
# This script automates the process of building and running the Docker container
|
||||
# with version information dynamically injected at build time.
|
||||
|
||||
# Exit immediately if a command exits with a non-zero status.
|
||||
# Hidden feature: Preserve usage statistics across rebuilds
|
||||
# Usage: ./docker-build.sh --with-usage
|
||||
# First run prompts for management API key, saved to temp/stats/.api_secret
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
STATS_DIR="temp/stats"
|
||||
STATS_FILE="${STATS_DIR}/.usage_backup.json"
|
||||
SECRET_FILE="${STATS_DIR}/.api_secret"
|
||||
WITH_USAGE=false
|
||||
|
||||
get_port() {
|
||||
if [[ -f "config.yaml" ]]; then
|
||||
grep -E "^port:" config.yaml | sed -E 's/^port: *["'"'"']?([0-9]+)["'"'"']?.*$/\1/'
|
||||
else
|
||||
echo "8317"
|
||||
fi
|
||||
}
|
||||
|
||||
export_stats_api_secret() {
|
||||
if [[ -f "${SECRET_FILE}" ]]; then
|
||||
API_SECRET=$(cat "${SECRET_FILE}")
|
||||
else
|
||||
if [[ ! -d "${STATS_DIR}" ]]; then
|
||||
mkdir -p "${STATS_DIR}"
|
||||
fi
|
||||
echo "First time using --with-usage. Management API key required."
|
||||
read -r -p "Enter management key: " -s API_SECRET
|
||||
echo
|
||||
echo "${API_SECRET}" > "${SECRET_FILE}"
|
||||
chmod 600 "${SECRET_FILE}"
|
||||
fi
|
||||
}
|
||||
|
||||
check_container_running() {
|
||||
local port
|
||||
port=$(get_port)
|
||||
|
||||
if ! curl -s -o /dev/null -w "%{http_code}" "http://localhost:${port}/" | grep -q "200"; then
|
||||
echo "Error: cli-proxy-api service is not responding at localhost:${port}"
|
||||
echo "Please start the container first or use without --with-usage flag."
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
export_stats() {
|
||||
local port
|
||||
port=$(get_port)
|
||||
|
||||
if [[ ! -d "${STATS_DIR}" ]]; then
|
||||
mkdir -p "${STATS_DIR}"
|
||||
fi
|
||||
check_container_running
|
||||
echo "Exporting usage statistics..."
|
||||
EXPORT_RESPONSE=$(curl -s -w "\n%{http_code}" -H "X-Management-Key: ${API_SECRET}" \
|
||||
"http://localhost:${port}/v0/management/usage/export")
|
||||
HTTP_CODE=$(echo "${EXPORT_RESPONSE}" | tail -n1)
|
||||
RESPONSE_BODY=$(echo "${EXPORT_RESPONSE}" | sed '$d')
|
||||
|
||||
if [[ "${HTTP_CODE}" != "200" ]]; then
|
||||
echo "Export failed (HTTP ${HTTP_CODE}): ${RESPONSE_BODY}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "${RESPONSE_BODY}" > "${STATS_FILE}"
|
||||
echo "Statistics exported to ${STATS_FILE}"
|
||||
}
|
||||
|
||||
import_stats() {
|
||||
local port
|
||||
port=$(get_port)
|
||||
|
||||
echo "Importing usage statistics..."
|
||||
IMPORT_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST \
|
||||
-H "X-Management-Key: ${API_SECRET}" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d @"${STATS_FILE}" \
|
||||
"http://localhost:${port}/v0/management/usage/import")
|
||||
IMPORT_CODE=$(echo "${IMPORT_RESPONSE}" | tail -n1)
|
||||
IMPORT_BODY=$(echo "${IMPORT_RESPONSE}" | sed '$d')
|
||||
|
||||
if [[ "${IMPORT_CODE}" == "200" ]]; then
|
||||
echo "Statistics imported successfully"
|
||||
else
|
||||
echo "Import failed (HTTP ${IMPORT_CODE}): ${IMPORT_BODY}"
|
||||
fi
|
||||
|
||||
rm -f "${STATS_FILE}"
|
||||
}
|
||||
|
||||
wait_for_service() {
|
||||
local port
|
||||
port=$(get_port)
|
||||
|
||||
echo "Waiting for service to be ready..."
|
||||
for i in {1..30}; do
|
||||
if curl -s -o /dev/null -w "%{http_code}" "http://localhost:${port}/" | grep -q "200"; then
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
sleep 2
|
||||
}
|
||||
|
||||
if [[ "${1:-}" == "--with-usage" ]]; then
|
||||
WITH_USAGE=true
|
||||
export_stats_api_secret
|
||||
fi
|
||||
|
||||
# --- Step 1: Choose Environment ---
|
||||
echo "Please select an option:"
|
||||
echo "1) Run using Pre-built Image (Recommended)"
|
||||
@@ -18,7 +124,14 @@ read -r -p "Enter choice [1-2]: " choice
|
||||
case "$choice" in
|
||||
1)
|
||||
echo "--- Running with Pre-built Image ---"
|
||||
if [[ "${WITH_USAGE}" == "true" ]]; then
|
||||
export_stats
|
||||
fi
|
||||
docker compose up -d --remove-orphans --no-build
|
||||
if [[ "${WITH_USAGE}" == "true" ]]; then
|
||||
wait_for_service
|
||||
import_stats
|
||||
fi
|
||||
echo "Services are starting from remote image."
|
||||
echo "Run 'docker compose logs -f' to see the logs."
|
||||
;;
|
||||
@@ -38,16 +151,25 @@ case "$choice" in
|
||||
|
||||
# Build and start the services with a local-only image tag
|
||||
export CLI_PROXY_IMAGE="cli-proxy-api:local"
|
||||
|
||||
|
||||
echo "Building the Docker image..."
|
||||
docker compose build \
|
||||
--build-arg VERSION="${VERSION}" \
|
||||
--build-arg COMMIT="${COMMIT}" \
|
||||
--build-arg BUILD_DATE="${BUILD_DATE}"
|
||||
|
||||
if [[ "${WITH_USAGE}" == "true" ]]; then
|
||||
export_stats
|
||||
fi
|
||||
|
||||
echo "Starting the services..."
|
||||
docker compose up -d --remove-orphans --pull never
|
||||
|
||||
if [[ "${WITH_USAGE}" == "true" ]]; then
|
||||
wait_for_service
|
||||
import_stats
|
||||
fi
|
||||
|
||||
echo "Build complete. Services are starting."
|
||||
echo "Run 'docker compose logs -f' to see the logs."
|
||||
;;
|
||||
@@ -55,4 +177,4 @@ case "$choice" in
|
||||
echo "Invalid choice. Please enter 1 or 2."
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
esac
|
||||
|
||||
@@ -22,7 +22,7 @@ services:
|
||||
- "51121:51121"
|
||||
- "11451:11451"
|
||||
volumes:
|
||||
- ./config.yaml:/CLIProxyAPI/config.yaml
|
||||
- ./auths:/root/.cli-proxy-api
|
||||
- ./logs:/CLIProxyAPI/logs
|
||||
- ${CLI_PROXY_CONFIG_PATH:-./config.yaml}:/CLIProxyAPI/config.yaml
|
||||
- ${CLI_PROXY_AUTH_PATH:-./auths}:/root/.cli-proxy-api
|
||||
- ${CLI_PROXY_LOG_PATH:-./logs}:/CLIProxyAPI/logs
|
||||
restart: unless-stopped
|
||||
|
||||
@@ -1,443 +0,0 @@
|
||||
# Amp CLI Integration Guide
|
||||
|
||||
This guide explains how to use CLIProxyAPI with Amp CLI and Amp IDE extensions, enabling you to use your existing Google/ChatGPT/Claude subscriptions (via OAuth) with Amp's CLI.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Overview](#overview)
|
||||
- [Which Providers Should You Authenticate?](#which-providers-should-you-authenticate)
|
||||
- [Architecture](#architecture)
|
||||
- [Configuration](#configuration)
|
||||
- [Model Mapping Configuration](#model-mapping-configuration)
|
||||
- [Setup](#setup)
|
||||
- [Usage](#usage)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
|
||||
## Overview
|
||||
|
||||
The Amp CLI integration adds specialized routing to support Amp's API patterns while maintaining full compatibility with all existing CLIProxyAPI features. This allows you to use both traditional CLIProxyAPI features and Amp CLI with the same proxy server.
|
||||
|
||||
### Key Features
|
||||
|
||||
- **Provider route aliases**: Maps Amp's `/api/provider/{provider}/v1...` patterns to CLIProxyAPI handlers
|
||||
- **Management proxy**: Forwards OAuth and account management requests to Amp's control plane
|
||||
- **Smart fallback**: Automatically routes unconfigured models to ampcode.com
|
||||
- **Model mapping**: Route unavailable models to alternatives you have access to (e.g., `claude-opus-4.5` → `claude-sonnet-4`)
|
||||
- **Secret management**: Configurable precedence (config > env > file) with 5-minute caching
|
||||
- **Security-first**: Management routes restricted to localhost by default
|
||||
- **Automatic gzip handling**: Decompresses responses from Amp upstream
|
||||
|
||||
### What You Can Do
|
||||
|
||||
- Use Amp CLI with your Google account (Gemini 3 Pro Preview, Gemini 2.5 Pro, Gemini 2.5 Flash)
|
||||
- Use Amp CLI with your ChatGPT Plus/Pro subscription (GPT-5, GPT-5 Codex models)
|
||||
- Use Amp CLI with your Claude Pro/Max subscription (Claude Sonnet 4.5, Opus 4.1)
|
||||
- Use Amp IDE extensions (VS Code, Cursor, Windsurf, etc.) with the same proxy
|
||||
- Run multiple CLI tools (Factory + Amp) through one proxy server
|
||||
- Route unconfigured models automatically through ampcode.com
|
||||
|
||||
### Which Providers Should You Authenticate?
|
||||
|
||||
**Important**: The providers you need to authenticate depend on which models and features your installed version of Amp currently uses. Amp employs different providers for various agent modes and specialized subagents:
|
||||
|
||||
- **Smart mode**: Uses Google/Gemini models (Gemini 3 Pro)
|
||||
- **Rush mode**: Uses Anthropic/Claude models (Claude Haiku 4.5)
|
||||
- **Oracle subagent**: Uses OpenAI/GPT models (GPT-5 medium reasoning)
|
||||
- **Librarian subagent**: Uses Anthropic/Claude models (Claude Sonnet 4.5)
|
||||
- **Search subagent**: Uses Anthropic/Claude models (Claude Haiku 4.5)
|
||||
- **Review feature**: Uses Google/Gemini models (Gemini 2.5 Flash-Lite)
|
||||
|
||||
For the most current information about which models Amp uses, see the **[Amp Models Documentation](https://ampcode.com/models)**.
|
||||
|
||||
#### Fallback Behavior
|
||||
|
||||
CLIProxyAPI uses a smart fallback system:
|
||||
|
||||
1. **Provider authenticated locally** (`--login`, `--codex-login`, `--claude-login`):
|
||||
- Requests use **your OAuth subscription** (ChatGPT Plus/Pro, Claude Pro/Max, Google account)
|
||||
- You benefit from your subscription's included usage quotas
|
||||
- No Amp credits consumed
|
||||
|
||||
2. **Provider NOT authenticated locally**:
|
||||
- Requests automatically forward to **ampcode.com**
|
||||
- Uses Amp's backend provider connections
|
||||
- **Requires Amp credits** if the provider is paid (OpenAI, Anthropic paid tiers)
|
||||
- May result in errors if Amp credit balance is insufficient
|
||||
|
||||
**Recommendation**: Authenticate all providers you have subscriptions for to maximize value and minimize Amp credit usage. If you don't have subscriptions to all providers Amp uses, ensure you have sufficient Amp credits available for fallback requests.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Request Flow
|
||||
|
||||
```
|
||||
Amp CLI/IDE
|
||||
↓
|
||||
├─ Provider API requests (/api/provider/{provider}/v1/...)
|
||||
│ ↓
|
||||
│ ├─ Model configured locally?
|
||||
│ │ YES → Use local OAuth tokens (OpenAI/Claude/Gemini handlers)
|
||||
│ │ NO ↓
|
||||
│ │ ├─ Model mapping configured?
|
||||
│ │ │ YES → Rewrite model → Use local handler (free)
|
||||
│ │ │ NO → Forward to ampcode.com (uses Amp credits)
|
||||
│ ↓
|
||||
│ Response
|
||||
│
|
||||
└─ Management requests (/api/auth, /api/user, /api/threads, ...)
|
||||
↓
|
||||
├─ Localhost check (security)
|
||||
↓
|
||||
└─ Reverse proxy to ampcode.com
|
||||
↓
|
||||
Response (auto-decompressed if gzipped)
|
||||
```
|
||||
|
||||
### Components
|
||||
|
||||
The Amp integration is implemented as a modular routing module (`internal/api/modules/amp/`) with these components:
|
||||
|
||||
1. **Route Aliases** (`routes.go`): Maps Amp-style paths to standard handlers
|
||||
2. **Reverse Proxy** (`proxy.go`): Forwards management requests to ampcode.com
|
||||
3. **Fallback Handler** (`fallback_handlers.go`): Routes unconfigured models to ampcode.com
|
||||
4. **Secret Management** (`secret.go`): Multi-source API key resolution with caching
|
||||
5. **Main Module** (`amp.go`): Orchestrates registration and configuration
|
||||
|
||||
## Configuration
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
Add these fields to your `config.yaml`:
|
||||
|
||||
```yaml
|
||||
# Amp upstream control plane (required for management routes)
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
|
||||
# Optional: Override API key (otherwise uses env or file)
|
||||
# amp-upstream-api-key: "your-amp-api-key"
|
||||
|
||||
# Security: restrict management routes to localhost (recommended)
|
||||
amp-restrict-management-to-localhost: true
|
||||
```
|
||||
|
||||
### Model Mapping Configuration
|
||||
|
||||
When Amp CLI requests a model that you don't have access to, you can configure mappings to route those requests to alternative models that you DO have available. This avoids consuming Amp credits for models you could handle locally.
|
||||
|
||||
```yaml
|
||||
# Route unavailable models to alternatives
|
||||
amp-model-mappings:
|
||||
# Example: Route Claude Opus 4.5 requests to Claude Sonnet 4
|
||||
- from: "claude-opus-4.5"
|
||||
to: "claude-sonnet-4"
|
||||
|
||||
# Example: Route GPT-5 requests to Gemini 2.5 Pro
|
||||
- from: "gpt-5"
|
||||
to: "gemini-2.5-pro"
|
||||
|
||||
# Example: Map older model names to newer versions
|
||||
- from: "claude-3-opus-20240229"
|
||||
to: "claude-3-5-sonnet-20241022"
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
|
||||
1. Amp CLI requests a model (e.g., `claude-opus-4.5`)
|
||||
2. CLIProxyAPI checks if a local provider is available for that model
|
||||
3. If not available, it checks the model mappings
|
||||
4. If a mapping exists, the request is rewritten to use the target model
|
||||
5. The request is then handled locally (free, using your OAuth subscription)
|
||||
|
||||
**Benefits:**
|
||||
- **Save Amp credits**: Use your local subscriptions instead of forwarding to ampcode.com
|
||||
- **Hot-reload**: Mappings can be updated without restarting the proxy
|
||||
- **Structured logging**: Clear logs show when mappings are applied
|
||||
|
||||
**Routing Decision Logs:**
|
||||
|
||||
The proxy logs each routing decision with structured fields:
|
||||
|
||||
```
|
||||
[AMP] Using local provider for model: gemini-2.5-pro # Local provider (free)
|
||||
[AMP] Model mapped: claude-opus-4.5 -> claude-sonnet-4 # Mapping applied (free)
|
||||
[AMP] Forwarding to ampcode.com (uses Amp credits) - model_id: gpt-5 # Fallback (costs credits)
|
||||
```
|
||||
|
||||
### Secret Resolution Precedence
|
||||
|
||||
The Amp module resolves API keys using this precedence order:
|
||||
|
||||
| Source | Key | Priority | Cache |
|
||||
|--------|-----|----------|-------|
|
||||
| Config file | `amp-upstream-api-key` | High | No |
|
||||
| Environment | `AMP_API_KEY` | Medium | No |
|
||||
| Amp secrets file | `~/.local/share/amp/secrets.json` | Low | 5 min |
|
||||
|
||||
**Recommendation**: Use the Amp secrets file (lowest precedence) for normal usage. This file is automatically managed by `amp login`.
|
||||
|
||||
### Security Settings
|
||||
|
||||
**`amp-restrict-management-to-localhost`** (default: `true`)
|
||||
|
||||
When enabled, management routes (`/api/auth`, `/api/user`, `/api/threads`, etc.) only accept connections from localhost (127.0.0.1, ::1). This prevents:
|
||||
- Drive-by browser attacks
|
||||
- Remote access to management endpoints
|
||||
- CORS-based attacks
|
||||
- Header spoofing attacks (e.g., `X-Forwarded-For: 127.0.0.1`)
|
||||
|
||||
#### How It Works
|
||||
|
||||
This restriction uses the **actual TCP connection address** (`RemoteAddr`), not HTTP headers like `X-Forwarded-For`. This prevents header spoofing attacks but has important implications:
|
||||
|
||||
- ✅ **Works for direct connections**: Running CLIProxyAPI directly on your machine or server
|
||||
- ⚠️ **May not work behind reverse proxies**: If deploying behind nginx, Cloudflare, or other proxies, the connection will appear to come from the proxy's IP, not localhost
|
||||
|
||||
#### Reverse Proxy Deployments
|
||||
|
||||
If you need to run CLIProxyAPI behind a reverse proxy (nginx, Caddy, Cloudflare Tunnel, etc.):
|
||||
|
||||
1. **Disable the localhost restriction**:
|
||||
```yaml
|
||||
amp-restrict-management-to-localhost: false
|
||||
```
|
||||
|
||||
2. **Use alternative security measures**:
|
||||
- Firewall rules restricting access to management routes
|
||||
- Proxy-level authentication (HTTP Basic Auth, OAuth)
|
||||
- Network-level isolation (VPN, Tailscale, Cloudflare Access)
|
||||
- Bind CLIProxyAPI to `127.0.0.1` only and access via SSH tunnel
|
||||
|
||||
3. **Example nginx configuration** (blocks external access to management routes):
|
||||
```nginx
|
||||
location /api/auth { deny all; }
|
||||
location /api/user { deny all; }
|
||||
location /api/threads { deny all; }
|
||||
location /api/internal { deny all; }
|
||||
```
|
||||
|
||||
**Important**: Only disable `amp-restrict-management-to-localhost` if you understand the security implications and have other protections in place.
|
||||
|
||||
## Setup
|
||||
|
||||
### 1. Configure CLIProxyAPI
|
||||
|
||||
Create or edit `config.yaml`:
|
||||
|
||||
```yaml
|
||||
port: 8317
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
|
||||
# Amp integration
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
amp-restrict-management-to-localhost: true
|
||||
|
||||
# Other standard settings...
|
||||
debug: false
|
||||
logging-to-file: true
|
||||
```
|
||||
|
||||
### 2. Authenticate with Providers
|
||||
|
||||
Run OAuth login for the providers you want to use:
|
||||
|
||||
**Google Account (Gemini 2.5 Pro, Gemini 2.5 Flash, Gemini 3 Pro Preview):**
|
||||
```bash
|
||||
./cli-proxy-api --login
|
||||
```
|
||||
|
||||
**ChatGPT Plus/Pro (GPT-5, GPT-5 Codex):**
|
||||
```bash
|
||||
./cli-proxy-api --codex-login
|
||||
```
|
||||
|
||||
**Claude Pro/Max (Claude Sonnet 4.5, Opus 4.1):**
|
||||
```bash
|
||||
./cli-proxy-api --claude-login
|
||||
```
|
||||
|
||||
Tokens are saved to:
|
||||
- Gemini: `~/.cli-proxy-api/gemini-<email>.json`
|
||||
- OpenAI Codex: `~/.cli-proxy-api/codex-<email>.json`
|
||||
- Claude: `~/.cli-proxy-api/claude-<email>.json`
|
||||
|
||||
### 3. Start the Proxy
|
||||
|
||||
```bash
|
||||
./cli-proxy-api --config config.yaml
|
||||
```
|
||||
|
||||
Or run in background with tmux (recommended for remote servers):
|
||||
|
||||
```bash
|
||||
tmux new-session -d -s proxy "./cli-proxy-api --config config.yaml"
|
||||
```
|
||||
|
||||
### 4. Configure Amp CLI
|
||||
|
||||
#### Option A: Settings File
|
||||
|
||||
Edit `~/.config/amp/settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"amp.url": "http://localhost:8317"
|
||||
}
|
||||
```
|
||||
|
||||
#### Option B: Environment Variable
|
||||
|
||||
```bash
|
||||
export AMP_URL=http://localhost:8317
|
||||
```
|
||||
|
||||
### 5. Login and Use Amp
|
||||
|
||||
Login through the proxy (proxied to ampcode.com):
|
||||
|
||||
```bash
|
||||
amp login
|
||||
```
|
||||
|
||||
Use Amp as normal:
|
||||
|
||||
```bash
|
||||
amp "Write a hello world program in Python"
|
||||
```
|
||||
|
||||
### 6. (Optional) Configure Amp IDE Extension
|
||||
|
||||
The proxy also works with Amp IDE extensions for VS Code, Cursor, Windsurf, etc.
|
||||
|
||||
1. Open Amp extension settings in your IDE
|
||||
2. Set **Amp URL** to `http://localhost:8317`
|
||||
3. Login with your Amp account
|
||||
4. Start using Amp in your IDE
|
||||
|
||||
Both CLI and IDE can use the proxy simultaneously.
|
||||
|
||||
## Usage
|
||||
|
||||
### Supported Routes
|
||||
|
||||
#### Provider Aliases (Always Available)
|
||||
|
||||
These routes work even without `amp-upstream-url` configured:
|
||||
|
||||
- `/api/provider/openai/v1/chat/completions`
|
||||
- `/api/provider/openai/v1/responses`
|
||||
- `/api/provider/anthropic/v1/messages`
|
||||
- `/api/provider/google/v1beta/models/:action`
|
||||
|
||||
Amp CLI calls these routes with your OAuth-authenticated models configured in CLIProxyAPI.
|
||||
|
||||
#### Management Routes (Require `amp-upstream-url`)
|
||||
|
||||
These routes are proxied to ampcode.com:
|
||||
|
||||
- `/api/auth` - Authentication
|
||||
- `/api/user` - User profile
|
||||
- `/api/meta` - Metadata
|
||||
- `/api/threads` - Conversation threads
|
||||
- `/api/telemetry` - Usage telemetry
|
||||
- `/api/internal` - Internal APIs
|
||||
|
||||
**Security**: Restricted to localhost by default.
|
||||
|
||||
### Model Fallback Behavior
|
||||
|
||||
When Amp requests a model:
|
||||
|
||||
1. **Check local configuration**: Does CLIProxyAPI have OAuth tokens for this model's provider?
|
||||
2. **If YES**: Route to local handler (use your OAuth subscription)
|
||||
3. **If NO**: Check if a model mapping exists
|
||||
4. **If mapping exists**: Rewrite request to mapped model → Route to local handler (free)
|
||||
5. **If no mapping**: Forward to ampcode.com (uses Amp credits)
|
||||
|
||||
This enables seamless mixed usage:
|
||||
- Models you've configured (Gemini, ChatGPT, Claude) → Your OAuth subscriptions
|
||||
- Models with mappings configured → Routed to alternative local models (free)
|
||||
- Models you haven't configured and have no mapping → Amp's default providers (uses credits)
|
||||
|
||||
### Example API Calls
|
||||
|
||||
**Chat completion with local OAuth:**
|
||||
```bash
|
||||
curl http://localhost:8317/api/provider/openai/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-5",
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
}'
|
||||
```
|
||||
|
||||
**Management endpoint (localhost only):**
|
||||
```bash
|
||||
curl http://localhost:8317/api/user
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
| Symptom | Likely Cause | Fix |
|
||||
|---------|--------------|-----|
|
||||
| 404 on `/api/provider/...` | Incorrect route path | Ensure exact path: `/api/provider/{provider}/v1...` |
|
||||
| 403 on `/api/user` | Non-localhost request | Run from same machine or disable `amp-restrict-management-to-localhost` (not recommended) |
|
||||
| 401/403 from provider | Missing/expired OAuth | Re-run `--codex-login` or `--claude-login` |
|
||||
| Amp gzip errors | Response decompression issue | Update to latest build; auto-decompression should handle this |
|
||||
| Models not using proxy | Wrong Amp URL | Verify `amp.url` setting or `AMP_URL` environment variable |
|
||||
| CORS errors | Protected management endpoint | Use CLI/terminal, not browser |
|
||||
|
||||
### Diagnostics
|
||||
|
||||
**Check proxy logs:**
|
||||
```bash
|
||||
# If logging-to-file: true
|
||||
tail -f logs/requests.log
|
||||
|
||||
# If running in tmux
|
||||
tmux attach-session -t proxy
|
||||
```
|
||||
|
||||
**Enable debug mode** (temporarily):
|
||||
```yaml
|
||||
debug: true
|
||||
```
|
||||
|
||||
**Test basic connectivity:**
|
||||
```bash
|
||||
# Check if proxy is running
|
||||
curl http://localhost:8317/v1/models
|
||||
|
||||
# Check Amp-specific route
|
||||
curl http://localhost:8317/api/provider/openai/v1/models
|
||||
```
|
||||
|
||||
**Verify Amp configuration:**
|
||||
```bash
|
||||
# Check if Amp is using proxy
|
||||
amp config get amp.url
|
||||
|
||||
# Or check environment
|
||||
echo $AMP_URL
|
||||
```
|
||||
|
||||
### Security Checklist
|
||||
|
||||
- ✅ Keep `amp-restrict-management-to-localhost: true` (default)
|
||||
- ✅ Don't expose proxy publicly (bind to localhost or use firewall/VPN)
|
||||
- ✅ Use the Amp secrets file (`~/.local/share/amp/secrets.json`) managed by `amp login`
|
||||
- ✅ Rotate OAuth tokens periodically by re-running login commands
|
||||
- ✅ Store config and auth-dir on encrypted disk if handling sensitive data
|
||||
- ✅ Keep proxy binary up to date for security fixes
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [CLIProxyAPI Main Documentation](https://help.router-for.me/)
|
||||
- [Amp CLI Official Manual](https://ampcode.com/manual)
|
||||
- [Management API Reference](https://help.router-for.me/management/api)
|
||||
- [SDK Documentation](sdk-usage.md)
|
||||
|
||||
## Disclaimer
|
||||
|
||||
This integration is for personal/educational use. Using reverse proxies or alternate API bases may violate provider Terms of Service. You are solely responsible for how you use this software. Accounts may be rate-limited, locked, or banned. No warranties. Use at your own risk.
|
||||
@@ -1,392 +0,0 @@
|
||||
# Amp CLI 集成指南
|
||||
|
||||
本指南说明如何在 Amp CLI 和 Amp IDE 扩展中使用 CLIProxyAPI,通过 OAuth 让你能够把已有的 Google/ChatGPT/Claude 订阅与 Amp 的 CLI 一起使用。
|
||||
|
||||
## 目录
|
||||
|
||||
- [概述](#概述)
|
||||
- [应该认证哪些服务提供商?](#应该认证哪些服务提供商)
|
||||
- [架构](#架构)
|
||||
- [配置](#配置)
|
||||
- [设置](#设置)
|
||||
- [用法](#用法)
|
||||
- [故障排查](#故障排查)
|
||||
|
||||
## 概述
|
||||
|
||||
Amp CLI 集成为 Amp 的 API 模式添加了专用路由,同时保持与现有 CLIProxyAPI 功能的完全兼容。这样你可以在同一个代理服务器上同时使用传统 CLIProxyAPI 功能和 Amp CLI。
|
||||
|
||||
### 主要特性
|
||||
|
||||
- **提供者路由别名**:将 Amp 的 `/api/provider/{provider}/v1...` 路径映射到 CLIProxyAPI 处理器
|
||||
- **管理代理**:将 OAuth 和账号管理请求转发到 Amp 控制平面
|
||||
- **智能回退**:自动将未配置的模型路由到 ampcode.com
|
||||
- **密钥管理**:可配置优先级(配置 > 环境变量 > 文件),缓存 5 分钟
|
||||
- **安全优先**:管理路由默认限制为 localhost
|
||||
- **自动 gzip 处理**:自动解压来自 Amp 上游的响应
|
||||
|
||||
### 你可以做什么
|
||||
|
||||
- 使用 Amp CLI 搭配你的 Google 账号(Gemini 3 Pro Preview、Gemini 2.5 Pro、Gemini 2.5 Flash)
|
||||
- 使用 Amp CLI 搭配你的 ChatGPT Plus/Pro 订阅(GPT-5、GPT-5 Codex 模型)
|
||||
- 使用 Amp CLI 搭配你的 Claude Pro/Max 订阅(Claude Sonnet 4.5、Opus 4.1)
|
||||
- 将 Amp IDE 扩展(VS Code、Cursor、Windsurf 等)与同一个代理一起使用
|
||||
- 通过一个代理同时运行多个 CLI 工具(Factory + Amp)
|
||||
- 将未配置的模型自动路由到 ampcode.com
|
||||
|
||||
### 应该认证哪些服务提供商?
|
||||
|
||||
**重要**:需要认证的提供商取决于你安装的 Amp 版本当前使用的模型和功能。Amp 的不同智能模式和子代理会使用不同的提供商:
|
||||
|
||||
- **Smart 模式**:使用 Google/Gemini 模型(Gemini 3 Pro)
|
||||
- **Rush 模式**:使用 Anthropic/Claude 模型(Claude Haiku 4.5)
|
||||
- **Oracle 子代理**:使用 OpenAI/GPT 模型(GPT-5 medium reasoning)
|
||||
- **Librarian 子代理**:使用 Anthropic/Claude 模型(Claude Sonnet 4.5)
|
||||
- **Search 子代理**:使用 Anthropic/Claude 模型(Claude Haiku 4.5)
|
||||
- **Review 功能**:使用 Google/Gemini 模型(Gemini 2.5 Flash-Lite)
|
||||
|
||||
有关 Amp 当前使用哪些模型的最新信息,请参阅 **[Amp 模型文档](https://ampcode.com/models)**。
|
||||
|
||||
#### 回退行为
|
||||
|
||||
CLIProxyAPI 采用智能回退机制:
|
||||
|
||||
1. **本地已认证提供商**(`--login`、`--codex-login`、`--claude-login`):
|
||||
- 请求使用**你的 OAuth 订阅**(ChatGPT Plus/Pro、Claude Pro/Max、Google 账号)
|
||||
- 享受订阅自带的额度
|
||||
- 不消耗 Amp 额度
|
||||
|
||||
2. **本地未认证提供商**:
|
||||
- 请求自动转发到 **ampcode.com**
|
||||
- 使用 Amp 的后端提供商连接
|
||||
- 如果提供商是付费的(OpenAI、Anthropic 付费档),**需要消耗 Amp 额度**
|
||||
- 若 Amp 额度不足,可能产生错误
|
||||
|
||||
**建议**:对你有订阅的所有提供商都进行认证,以最大化价值并尽量减少 Amp 额度消耗。如果没有覆盖 Amp 使用的全部提供商,请确保为回退请求准备足够的 Amp 额度。
|
||||
|
||||
## 架构
|
||||
|
||||
### 请求流
|
||||
|
||||
```
|
||||
Amp CLI/IDE
|
||||
↓
|
||||
├─ Provider API requests (/api/provider/{provider}/v1/...)
|
||||
│ ↓
|
||||
│ ├─ Model configured locally?
|
||||
│ │ YES → Use local OAuth tokens (OpenAI/Claude/Gemini handlers)
|
||||
│ │ NO → Forward to ampcode.com (reverse proxy)
|
||||
│ ↓
|
||||
│ Response
|
||||
│
|
||||
└─ Management requests (/api/auth, /api/user, /api/threads, ...)
|
||||
↓
|
||||
├─ Localhost check (security)
|
||||
↓
|
||||
└─ Reverse proxy to ampcode.com
|
||||
↓
|
||||
Response (auto-decompressed if gzipped)
|
||||
```
|
||||
|
||||
### 组件
|
||||
|
||||
Amp 集成以模块化路由模块(`internal/api/modules/amp/`)实现,包含以下组件:
|
||||
|
||||
1. **路由别名**(`routes.go`):将 Amp 风格的路径映射到标准处理器
|
||||
2. **反向代理**(`proxy.go`):将管理请求转发到 ampcode.com
|
||||
3. **回退处理器**(`fallback_handlers.go`):将未配置的模型路由到 ampcode.com
|
||||
4. **密钥管理**(`secret.go`):多来源 API 密钥解析并带缓存
|
||||
5. **主模块**(`amp.go`):负责注册和配置
|
||||
|
||||
## 配置
|
||||
|
||||
### 基础配置
|
||||
|
||||
在 `config.yaml` 中新增以下字段:
|
||||
|
||||
```yaml
|
||||
# Amp 上游控制平面(管理路由必需)
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
|
||||
# 可选:覆盖 API key(否则使用环境变量或文件)
|
||||
# amp-upstream-api-key: "your-amp-api-key"
|
||||
|
||||
# 安全性:将管理路由限制为 localhost(推荐)
|
||||
amp-restrict-management-to-localhost: true
|
||||
```
|
||||
|
||||
### 密钥解析优先级
|
||||
|
||||
Amp 模块以如下优先级解析 API key:
|
||||
|
||||
| 来源 | 键名 | 优先级 | 缓存 |
|
||||
|------|------|--------|------|
|
||||
| 配置文件 | `amp-upstream-api-key` | 高 | 无 |
|
||||
| 环境变量 | `AMP_API_KEY` | 中 | 无 |
|
||||
| Amp 密钥文件 | `~/.local/share/amp/secrets.json` | 低 | 5 分钟 |
|
||||
|
||||
**建议**:日常使用时采用 Amp 密钥文件(最低优先级)。该文件由 `amp login` 自动管理。
|
||||
|
||||
### 安全设置
|
||||
|
||||
**`amp-restrict-management-to-localhost`**(默认:`true`)
|
||||
|
||||
启用后,管理路由(`/api/auth`、`/api/user`、`/api/threads` 等)只接受来自 localhost(127.0.0.1、::1)的连接,可防止:
|
||||
- 浏览器探测式攻击
|
||||
- 对管理端点的远程访问
|
||||
- 基于 CORS 的攻击
|
||||
- 伪造头攻击(例如 `X-Forwarded-For: 127.0.0.1`)
|
||||
|
||||
#### 工作原理
|
||||
|
||||
此限制使用**实际的 TCP 连接地址**(`RemoteAddr`),而非 `X-Forwarded-For` 等 HTTP 头,能防止头部伪造,但有重要影响:
|
||||
|
||||
- ✅ **直接连接可用**:在本机或服务器直接运行 CLIProxyAPI 时适用
|
||||
- ⚠️ **可能不适用于反向代理场景**:部署在 nginx、Cloudflare 等代理后,请求源会显示为代理 IP 而非 localhost
|
||||
|
||||
#### 反向代理部署
|
||||
|
||||
若需要在反向代理(nginx、Caddy、Cloudflare Tunnel 等)后运行 CLIProxyAPI:
|
||||
|
||||
1. **关闭 localhost 限制**:
|
||||
```yaml
|
||||
amp-restrict-management-to-localhost: false
|
||||
```
|
||||
|
||||
2. **使用替代安全措施**:
|
||||
- 防火墙规则限制管理路由访问
|
||||
- 代理层认证(HTTP Basic Auth、OAuth)
|
||||
- 网络隔离(VPN、Tailscale、Cloudflare Access)
|
||||
- 将 CLIProxyAPI 仅绑定 `127.0.0.1`,并通过 SSH 隧道访问
|
||||
|
||||
3. **nginx 示例配置**(阻止外部访问管理路由):
|
||||
```nginx
|
||||
location /api/auth { deny all; }
|
||||
location /api/user { deny all; }
|
||||
location /api/threads { deny all; }
|
||||
location /api/internal { deny all; }
|
||||
```
|
||||
|
||||
**重要**:只有在理解安全影响并已采取其他防护措施时,才关闭 `amp-restrict-management-to-localhost`。
|
||||
|
||||
## 设置
|
||||
|
||||
### 1. 配置 CLIProxyAPI
|
||||
|
||||
创建或编辑 `config.yaml`:
|
||||
|
||||
```yaml
|
||||
port: 8317
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
|
||||
# Amp 集成
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
amp-restrict-management-to-localhost: true
|
||||
|
||||
# 其他常规设置...
|
||||
debug: false
|
||||
logging-to-file: true
|
||||
```
|
||||
|
||||
### 2. 认证提供商
|
||||
|
||||
为要使用的提供商执行 OAuth 登录:
|
||||
|
||||
**Google 账号(Gemini 2.5 Pro、Gemini 2.5 Flash、Gemini 3 Pro Preview):**
|
||||
```bash
|
||||
./cli-proxy-api --login
|
||||
```
|
||||
|
||||
**ChatGPT Plus/Pro(GPT-5、GPT-5 Codex):**
|
||||
```bash
|
||||
./cli-proxy-api --codex-login
|
||||
```
|
||||
|
||||
**Claude Pro/Max(Claude Sonnet 4.5、Opus 4.1):**
|
||||
```bash
|
||||
./cli-proxy-api --claude-login
|
||||
```
|
||||
|
||||
令牌会保存到:
|
||||
- Gemini: `~/.cli-proxy-api/gemini-<email>.json`
|
||||
- OpenAI Codex: `~/.cli-proxy-api/codex-<email>.json`
|
||||
- Claude: `~/.cli-proxy-api/claude-<email>.json`
|
||||
|
||||
### 3. 启动代理
|
||||
|
||||
```bash
|
||||
./cli-proxy-api --config config.yaml
|
||||
```
|
||||
|
||||
或使用 tmux 在后台运行(推荐用于远程服务器):
|
||||
|
||||
```bash
|
||||
tmux new-session -d -s proxy "./cli-proxy-api --config config.yaml"
|
||||
```
|
||||
|
||||
### 4. 配置 Amp CLI
|
||||
|
||||
#### 方案 A:配置文件
|
||||
|
||||
编辑 `~/.config/amp/settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"amp.url": "http://localhost:8317"
|
||||
}
|
||||
```
|
||||
|
||||
#### 方案 B:环境变量
|
||||
|
||||
```bash
|
||||
export AMP_URL=http://localhost:8317
|
||||
```
|
||||
|
||||
### 5. 登录并使用 Amp
|
||||
|
||||
通过代理登录(请求会被代理到 ampcode.com):
|
||||
|
||||
```bash
|
||||
amp login
|
||||
```
|
||||
|
||||
像平常一样使用 Amp:
|
||||
|
||||
```bash
|
||||
amp "Write a hello world program in Python"
|
||||
```
|
||||
|
||||
### 6. (可选)配置 Amp IDE 扩展
|
||||
|
||||
该代理同样适用于 VS Code、Cursor、Windsurf 等 Amp IDE 扩展。
|
||||
|
||||
1. 在 IDE 中打开 Amp 扩展设置
|
||||
2. 将 **Amp URL** 设置为 `http://localhost:8317`
|
||||
3. 用你的 Amp 账号登录
|
||||
4. 在 IDE 中开始使用 Amp
|
||||
|
||||
CLI 和 IDE 可同时使用该代理。
|
||||
|
||||
## 用法
|
||||
|
||||
### 支持的路由
|
||||
|
||||
#### 提供商别名(始终可用)
|
||||
|
||||
这些路由即使未配置 `amp-upstream-url` 也可使用:
|
||||
|
||||
- `/api/provider/openai/v1/chat/completions`
|
||||
- `/api/provider/openai/v1/responses`
|
||||
- `/api/provider/anthropic/v1/messages`
|
||||
- `/api/provider/google/v1beta/models/:action`
|
||||
|
||||
Amp CLI 会使用你在 CLIProxyAPI 中通过 OAuth 认证的模型来调用这些路由。
|
||||
|
||||
#### 管理路由(需要 `amp-upstream-url`)
|
||||
|
||||
这些路由会被代理到 ampcode.com:
|
||||
|
||||
- `/api/auth` - 认证
|
||||
- `/api/user` - 用户资料
|
||||
- `/api/meta` - 元数据
|
||||
- `/api/threads` - 会话线程
|
||||
- `/api/telemetry` - 使用遥测
|
||||
- `/api/internal` - 内部 API
|
||||
|
||||
**安全性**:默认限制为 localhost。
|
||||
|
||||
### 模型回退行为
|
||||
|
||||
当 Amp 请求模型时:
|
||||
|
||||
1. **检查本地配置**:CLIProxyAPI 是否有该模型提供商的 OAuth 令牌?
|
||||
2. **如果有**:路由到本地处理器(使用你的 OAuth 订阅)
|
||||
3. **如果没有**:转发到 ampcode.com(使用 Amp 的默认路由)
|
||||
|
||||
这实现了无缝混用:
|
||||
- 你已配置的模型(Gemini、ChatGPT、Claude)→ 你的 OAuth 订阅
|
||||
- 未配置的模型 → Amp 的默认提供商
|
||||
|
||||
### 示例 API 调用
|
||||
|
||||
**使用本地 OAuth 的聊天补全:**
|
||||
```bash
|
||||
curl http://localhost:8317/api/provider/openai/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-5",
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
}'
|
||||
```
|
||||
|
||||
**管理端点(仅限 localhost):**
|
||||
```bash
|
||||
curl http://localhost:8317/api/user
|
||||
```
|
||||
|
||||
## 故障排查
|
||||
|
||||
### 常见问题
|
||||
|
||||
| 症状 | 可能原因 | 解决方案 |
|
||||
|------|----------|----------|
|
||||
| `/api/provider/...` 返回 404 | 路径错误 | 确保路径准确:`/api/provider/{provider}/v1...` |
|
||||
| `/api/user` 返回 403 | 非 localhost 请求 | 在同一机器上访问,或关闭 `amp-restrict-management-to-localhost`(不推荐) |
|
||||
| 提供商返回 401/403 | OAuth 缺失或过期 | 重新运行 `--codex-login` 或 `--claude-login` |
|
||||
| Amp gzip 错误 | 响应解压问题 | 更新到最新构建;自动解压应能处理 |
|
||||
| 模型未走代理 | Amp URL 设置错误 | 检查 `amp.url` 设置或 `AMP_URL` 环境变量 |
|
||||
| CORS 错误 | 受保护的管理端点 | 使用 CLI/终端而非浏览器 |
|
||||
|
||||
### 诊断
|
||||
|
||||
**查看代理日志:**
|
||||
```bash
|
||||
# 若 logging-to-file: true
|
||||
tail -f logs/requests.log
|
||||
|
||||
# 若运行在 tmux 中
|
||||
tmux attach-session -t proxy
|
||||
```
|
||||
|
||||
**临时开启调试模式:**
|
||||
```yaml
|
||||
debug: true
|
||||
```
|
||||
|
||||
**测试基础连通性:**
|
||||
```bash
|
||||
# 检查代理是否运行
|
||||
curl http://localhost:8317/v1/models
|
||||
|
||||
# 检查 Amp 特定路由
|
||||
curl http://localhost:8317/api/provider/openai/v1/models
|
||||
```
|
||||
|
||||
**验证 Amp 配置:**
|
||||
```bash
|
||||
# 检查 Amp 是否使用代理
|
||||
amp config get amp.url
|
||||
|
||||
# 或检查环境变量
|
||||
echo $AMP_URL
|
||||
```
|
||||
|
||||
### 安全清单
|
||||
|
||||
- ✅ 保持 `amp-restrict-management-to-localhost: true`(默认)
|
||||
- ✅ 不要将代理暴露到公共网络(绑定到 localhost 或使用防火墙/VPN)
|
||||
- ✅ 使用 `amp login` 管理的 Amp 密钥文件(`~/.local/share/amp/secrets.json`)
|
||||
- ✅ 定期重新登录轮换 OAuth 令牌
|
||||
- ✅ 若处理敏感数据,使用加密磁盘存储配置和 auth-dir
|
||||
- ✅ 保持代理二进制为最新版本以获取安全修复
|
||||
|
||||
## 其他资源
|
||||
|
||||
- [CLIProxyAPI 主文档](https://help.router-for.me/)
|
||||
- [Amp CLI 官方手册](https://ampcode.com/manual)
|
||||
- [管理 API 参考](https://help.router-for.me/management/api)
|
||||
- [SDK 文档](sdk-usage.md)
|
||||
|
||||
## 免责声明
|
||||
|
||||
此集成仅用于个人或教育用途。使用反向代理或替代 API 基址可能违反提供商的服务条款。你需要对自己的使用方式负责。账号可能会被限速、锁定或封禁。软件不附带任何保证,使用风险自负。
|
||||
@@ -7,80 +7,71 @@ The `github.com/router-for-me/CLIProxyAPI/v6/sdk/access` package centralizes inb
|
||||
```go
|
||||
import (
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
```
|
||||
|
||||
Add the module with `go get github.com/router-for-me/CLIProxyAPI/v6/sdk/access`.
|
||||
|
||||
## Provider Registry
|
||||
|
||||
Providers are registered globally and then attached to a `Manager` as a snapshot:
|
||||
|
||||
- `RegisterProvider(type, provider)` installs a pre-initialized provider instance.
|
||||
- Registration order is preserved the first time each `type` is seen.
|
||||
- `RegisteredProviders()` returns the providers in that order.
|
||||
|
||||
## Manager Lifecycle
|
||||
|
||||
```go
|
||||
manager := sdkaccess.NewManager()
|
||||
providers, err := sdkaccess.BuildProviders(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
manager.SetProviders(providers)
|
||||
manager.SetProviders(sdkaccess.RegisteredProviders())
|
||||
```
|
||||
|
||||
* `NewManager` constructs an empty manager.
|
||||
* `SetProviders` replaces the provider slice using a defensive copy.
|
||||
* `Providers` retrieves a snapshot that can be iterated safely from other goroutines.
|
||||
* `BuildProviders` translates `config.Config` access declarations into runnable providers. When the config omits explicit providers but defines inline API keys, the helper auto-installs the built-in `config-api-key` provider.
|
||||
|
||||
If the manager itself is `nil` or no providers are configured, the call returns `nil, nil`, allowing callers to treat access control as disabled.
|
||||
|
||||
## Authenticating Requests
|
||||
|
||||
```go
|
||||
result, err := manager.Authenticate(ctx, req)
|
||||
result, authErr := manager.Authenticate(ctx, req)
|
||||
switch {
|
||||
case err == nil:
|
||||
case authErr == nil:
|
||||
// Authentication succeeded; result describes the provider and principal.
|
||||
case errors.Is(err, sdkaccess.ErrNoCredentials):
|
||||
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeNoCredentials):
|
||||
// No recognizable credentials were supplied.
|
||||
case errors.Is(err, sdkaccess.ErrInvalidCredential):
|
||||
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeInvalidCredential):
|
||||
// Supplied credentials were present but rejected.
|
||||
default:
|
||||
// Transport-level failure was returned by a provider.
|
||||
// Internal/transport failure was returned by a provider.
|
||||
}
|
||||
```
|
||||
|
||||
`Manager.Authenticate` walks the configured providers in order. It returns on the first success, skips providers that surface `ErrNotHandled`, and tracks whether any provider reported `ErrNoCredentials` or `ErrInvalidCredential` for downstream error reporting.
|
||||
|
||||
If the manager itself is `nil` or no providers are registered, the call returns `nil, nil`, allowing callers to treat access control as disabled without branching on errors.
|
||||
`Manager.Authenticate` walks the configured providers in order. It returns on the first success, skips providers that return `AuthErrorCodeNotHandled`, and aggregates `AuthErrorCodeNoCredentials` / `AuthErrorCodeInvalidCredential` for a final result.
|
||||
|
||||
Each `Result` includes the provider identifier, the resolved principal, and optional metadata (for example, which header carried the credential).
|
||||
|
||||
## Configuration Layout
|
||||
## Built-in `config-api-key` Provider
|
||||
|
||||
The manager expects access providers under the `auth.providers` key inside `config.yaml`:
|
||||
The proxy includes one built-in access provider:
|
||||
|
||||
- `config-api-key`: Validates API keys declared under top-level `api-keys`.
|
||||
- Credential sources: `Authorization: Bearer`, `X-Goog-Api-Key`, `X-Api-Key`, `?key=`, `?auth_token=`
|
||||
- Metadata: `Result.Metadata["source"]` is set to the matched source label.
|
||||
|
||||
In the CLI server and `sdk/cliproxy`, this provider is registered automatically based on the loaded configuration.
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: inline-api
|
||||
type: config-api-key
|
||||
api-keys:
|
||||
- sk-test-123
|
||||
- sk-prod-456
|
||||
api-keys:
|
||||
- sk-test-123
|
||||
- sk-prod-456
|
||||
```
|
||||
|
||||
Fields map directly to `config.AccessProvider`: `name` labels the provider, `type` selects the registered factory, `sdk` can name an external module, `api-keys` seeds inline credentials, and `config` passes provider-specific options.
|
||||
## Loading Providers from External Go Modules
|
||||
|
||||
### Loading providers from external SDK modules
|
||||
|
||||
To consume a provider shipped in another Go module, point the `sdk` field at the module path and import it for its registration side effect:
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: partner-auth
|
||||
type: partner-token
|
||||
sdk: github.com/acme/xplatform/sdk/access/providers/partner
|
||||
config:
|
||||
region: us-west-2
|
||||
audience: cli-proxy
|
||||
```
|
||||
To consume a provider shipped in another Go module, import it for its registration side effect:
|
||||
|
||||
```go
|
||||
import (
|
||||
@@ -89,19 +80,11 @@ import (
|
||||
)
|
||||
```
|
||||
|
||||
The blank identifier import ensures `init` runs so `sdkaccess.RegisterProvider` executes before `BuildProviders` is called.
|
||||
|
||||
## Built-in Providers
|
||||
|
||||
The SDK ships with one provider out of the box:
|
||||
|
||||
- `config-api-key`: Validates API keys declared inline or under top-level `api-keys`. It accepts the key from `Authorization: Bearer`, `X-Goog-Api-Key`, `X-Api-Key`, or the `?key=` query string and reports `ErrInvalidCredential` when no match is found.
|
||||
|
||||
Additional providers can be delivered by third-party packages. When a provider package is imported, it registers itself with `sdkaccess.RegisterProvider`.
|
||||
The blank identifier import ensures `init` runs so `sdkaccess.RegisterProvider` executes before you call `RegisteredProviders()` (or before `cliproxy.NewBuilder().Build()`).
|
||||
|
||||
### Metadata and auditing
|
||||
|
||||
`Result.Metadata` carries provider-specific context. The built-in `config-api-key` provider, for example, stores the credential source (`authorization`, `x-goog-api-key`, `x-api-key`, or `query-key`). Populate this map in custom providers to enrich logs and downstream auditing.
|
||||
`Result.Metadata` carries provider-specific context. The built-in `config-api-key` provider, for example, stores the credential source (`authorization`, `x-goog-api-key`, `x-api-key`, `query-key`, `query-auth-token`). Populate this map in custom providers to enrich logs and downstream auditing.
|
||||
|
||||
## Writing Custom Providers
|
||||
|
||||
@@ -110,13 +93,13 @@ type customProvider struct{}
|
||||
|
||||
func (p *customProvider) Identifier() string { return "my-provider" }
|
||||
|
||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, error) {
|
||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, *sdkaccess.AuthError) {
|
||||
token := r.Header.Get("X-Custom")
|
||||
if token == "" {
|
||||
return nil, sdkaccess.ErrNoCredentials
|
||||
return nil, sdkaccess.NewNotHandledError()
|
||||
}
|
||||
if token != "expected" {
|
||||
return nil, sdkaccess.ErrInvalidCredential
|
||||
return nil, sdkaccess.NewInvalidCredentialError()
|
||||
}
|
||||
return &sdkaccess.Result{
|
||||
Provider: p.Identifier(),
|
||||
@@ -126,51 +109,46 @@ func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sd
|
||||
}
|
||||
|
||||
func init() {
|
||||
sdkaccess.RegisterProvider("custom", func(cfg *config.AccessProvider, root *config.Config) (sdkaccess.Provider, error) {
|
||||
return &customProvider{}, nil
|
||||
})
|
||||
sdkaccess.RegisterProvider("custom", &customProvider{})
|
||||
}
|
||||
```
|
||||
|
||||
A provider must implement `Identifier()` and `Authenticate()`. To expose it to configuration, call `RegisterProvider` inside `init`. Provider factories receive the specific `AccessProvider` block plus the full root configuration for contextual needs.
|
||||
A provider must implement `Identifier()` and `Authenticate()`. To make it available to the access manager, call `RegisterProvider` inside `init` with an initialized provider instance.
|
||||
|
||||
## Error Semantics
|
||||
|
||||
- `ErrNoCredentials`: no credentials were present or recognized by any provider.
|
||||
- `ErrInvalidCredential`: at least one provider processed the credentials but rejected them.
|
||||
- `ErrNotHandled`: instructs the manager to fall through to the next provider without affecting aggregate error reporting.
|
||||
- `NewNoCredentialsError()` (`AuthErrorCodeNoCredentials`): no credentials were present or recognized. (HTTP 401)
|
||||
- `NewInvalidCredentialError()` (`AuthErrorCodeInvalidCredential`): credentials were present but rejected. (HTTP 401)
|
||||
- `NewNotHandledError()` (`AuthErrorCodeNotHandled`): fall through to the next provider.
|
||||
- `NewInternalAuthError(message, cause)` (`AuthErrorCodeInternal`): transport/system failure. (HTTP 500)
|
||||
|
||||
Return custom errors to surface transport failures; they propagate immediately to the caller instead of being masked.
|
||||
Errors propagate immediately to the caller unless they are classified as `not_handled` / `no_credentials` / `invalid_credential` and can be aggregated by the manager.
|
||||
|
||||
## Integration with cliproxy Service
|
||||
|
||||
`sdk/cliproxy` wires `@sdk/access` automatically when you build a CLI service via `cliproxy.NewBuilder`. Supplying a preconfigured manager allows you to extend or override the default providers:
|
||||
`sdk/cliproxy` wires `@sdk/access` automatically when you build a CLI service via `cliproxy.NewBuilder`. Supplying a manager lets you reuse the same instance in your host process:
|
||||
|
||||
```go
|
||||
coreCfg, _ := config.LoadConfig("config.yaml")
|
||||
providers, _ := sdkaccess.BuildProviders(coreCfg)
|
||||
manager := sdkaccess.NewManager()
|
||||
manager.SetProviders(providers)
|
||||
accessManager := sdkaccess.NewManager()
|
||||
|
||||
svc, _ := cliproxy.NewBuilder().
|
||||
WithConfig(coreCfg).
|
||||
WithAccessManager(manager).
|
||||
WithConfigPath("config.yaml").
|
||||
WithRequestAccessManager(accessManager).
|
||||
Build()
|
||||
```
|
||||
|
||||
The service reuses the manager for every inbound request, ensuring consistent authentication across embedded deployments and the canonical CLI binary.
|
||||
Register any custom providers (typically via blank imports) before calling `Build()` so they are present in the global registry snapshot.
|
||||
|
||||
### Hot reloading providers
|
||||
### Hot reloading
|
||||
|
||||
When configuration changes, rebuild providers and swap them into the manager:
|
||||
When configuration changes, refresh any config-backed providers and then reset the manager's provider chain:
|
||||
|
||||
```go
|
||||
providers, err := sdkaccess.BuildProviders(newCfg)
|
||||
if err != nil {
|
||||
log.Errorf("reload auth providers failed: %v", err)
|
||||
return
|
||||
}
|
||||
accessManager.SetProviders(providers)
|
||||
// configaccess is github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access
|
||||
configaccess.Register(&newCfg.SDKConfig)
|
||||
accessManager.SetProviders(sdkaccess.RegisteredProviders())
|
||||
```
|
||||
|
||||
This mirrors the behaviour in `cliproxy.Service.refreshAccessProviders` and `api.Server.applyAccessConfig`, enabling runtime updates without restarting the process.
|
||||
This mirrors the behaviour in `internal/access.ApplyAccessProviders`, enabling runtime updates without restarting the process.
|
||||
|
||||
@@ -7,80 +7,71 @@
|
||||
```go
|
||||
import (
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
```
|
||||
|
||||
通过 `go get github.com/router-for-me/CLIProxyAPI/v6/sdk/access` 添加依赖。
|
||||
|
||||
## Provider Registry
|
||||
|
||||
访问提供者是全局注册,然后以快照形式挂到 `Manager` 上:
|
||||
|
||||
- `RegisterProvider(type, provider)` 注册一个已经初始化好的 provider 实例。
|
||||
- 每个 `type` 第一次出现时会记录其注册顺序。
|
||||
- `RegisteredProviders()` 会按该顺序返回 provider 列表。
|
||||
|
||||
## 管理器生命周期
|
||||
|
||||
```go
|
||||
manager := sdkaccess.NewManager()
|
||||
providers, err := sdkaccess.BuildProviders(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
manager.SetProviders(providers)
|
||||
manager.SetProviders(sdkaccess.RegisteredProviders())
|
||||
```
|
||||
|
||||
- `NewManager` 创建空管理器。
|
||||
- `SetProviders` 替换提供者切片并做防御性拷贝。
|
||||
- `Providers` 返回适合并发读取的快照。
|
||||
- `BuildProviders` 将 `config.Config` 中的访问配置转换成可运行的提供者。当配置没有显式声明但包含顶层 `api-keys` 时,会自动挂载内建的 `config-api-key` 提供者。
|
||||
|
||||
如果管理器本身为 `nil` 或未配置任何 provider,调用会返回 `nil, nil`,可视为关闭访问控制。
|
||||
|
||||
## 认证请求
|
||||
|
||||
```go
|
||||
result, err := manager.Authenticate(ctx, req)
|
||||
result, authErr := manager.Authenticate(ctx, req)
|
||||
switch {
|
||||
case err == nil:
|
||||
case authErr == nil:
|
||||
// Authentication succeeded; result carries provider and principal.
|
||||
case errors.Is(err, sdkaccess.ErrNoCredentials):
|
||||
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeNoCredentials):
|
||||
// No recognizable credentials were supplied.
|
||||
case errors.Is(err, sdkaccess.ErrInvalidCredential):
|
||||
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeInvalidCredential):
|
||||
// Credentials were present but rejected.
|
||||
default:
|
||||
// Provider surfaced a transport-level failure.
|
||||
}
|
||||
```
|
||||
|
||||
`Manager.Authenticate` 按配置顺序遍历提供者。遇到成功立即返回,`ErrNotHandled` 会继续尝试下一个;若发现 `ErrNoCredentials` 或 `ErrInvalidCredential`,会在遍历结束后汇总给调用方。
|
||||
|
||||
若管理器本身为 `nil` 或尚未注册提供者,调用会返回 `nil, nil`,让调用方无需针对错误做额外分支即可关闭访问控制。
|
||||
`Manager.Authenticate` 会按顺序遍历 provider:遇到成功立即返回,`AuthErrorCodeNotHandled` 会继续尝试下一个;`AuthErrorCodeNoCredentials` / `AuthErrorCodeInvalidCredential` 会在遍历结束后汇总给调用方。
|
||||
|
||||
`Result` 提供认证提供者标识、解析出的主体以及可选元数据(例如凭证来源)。
|
||||
|
||||
## 配置结构
|
||||
## 内建 `config-api-key` Provider
|
||||
|
||||
在 `config.yaml` 的 `auth.providers` 下定义访问提供者:
|
||||
代理内置一个访问提供者:
|
||||
|
||||
- `config-api-key`:校验 `config.yaml` 顶层的 `api-keys`。
|
||||
- 凭证来源:`Authorization: Bearer`、`X-Goog-Api-Key`、`X-Api-Key`、`?key=`、`?auth_token=`
|
||||
- 元数据:`Result.Metadata["source"]` 会写入匹配到的来源标识
|
||||
|
||||
在 CLI 服务端与 `sdk/cliproxy` 中,该 provider 会根据加载到的配置自动注册。
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: inline-api
|
||||
type: config-api-key
|
||||
api-keys:
|
||||
- sk-test-123
|
||||
- sk-prod-456
|
||||
api-keys:
|
||||
- sk-test-123
|
||||
- sk-prod-456
|
||||
```
|
||||
|
||||
条目映射到 `config.AccessProvider`:`name` 指定实例名,`type` 选择注册的工厂,`sdk` 可引用第三方模块,`api-keys` 提供内联凭证,`config` 用于传递特定选项。
|
||||
## 引入外部 Go 模块提供者
|
||||
|
||||
### 引入外部 SDK 提供者
|
||||
|
||||
若要消费其它 Go 模块输出的访问提供者,可在配置里填写 `sdk` 字段并在代码中引入该包,利用其 `init` 注册过程:
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: partner-auth
|
||||
type: partner-token
|
||||
sdk: github.com/acme/xplatform/sdk/access/providers/partner
|
||||
config:
|
||||
region: us-west-2
|
||||
audience: cli-proxy
|
||||
```
|
||||
若要消费其它 Go 模块输出的访问提供者,直接用空白标识符导入以触发其 `init` 注册即可:
|
||||
|
||||
```go
|
||||
import (
|
||||
@@ -89,19 +80,11 @@ import (
|
||||
)
|
||||
```
|
||||
|
||||
通过空白标识符导入即可确保 `init` 调用,先于 `BuildProviders` 完成 `sdkaccess.RegisterProvider`。
|
||||
|
||||
## 内建提供者
|
||||
|
||||
当前 SDK 默认内置:
|
||||
|
||||
- `config-api-key`:校验配置中的 API Key。它从 `Authorization: Bearer`、`X-Goog-Api-Key`、`X-Api-Key` 以及查询参数 `?key=` 提取凭证,不匹配时抛出 `ErrInvalidCredential`。
|
||||
|
||||
导入第三方包即可通过 `sdkaccess.RegisterProvider` 注册更多类型。
|
||||
空白导入可确保 `init` 先执行,从而在你调用 `RegisteredProviders()`(或 `cliproxy.NewBuilder().Build()`)之前完成 `sdkaccess.RegisterProvider`。
|
||||
|
||||
### 元数据与审计
|
||||
|
||||
`Result.Metadata` 用于携带提供者特定的上下文信息。内建的 `config-api-key` 会记录凭证来源(`authorization`、`x-goog-api-key`、`x-api-key` 或 `query-key`)。自定义提供者同样可以填充该 Map,以便丰富日志与审计场景。
|
||||
`Result.Metadata` 用于携带提供者特定的上下文信息。内建的 `config-api-key` 会记录凭证来源(`authorization`、`x-goog-api-key`、`x-api-key`、`query-key`、`query-auth-token`)。自定义提供者同样可以填充该 Map,以便丰富日志与审计场景。
|
||||
|
||||
## 编写自定义提供者
|
||||
|
||||
@@ -110,13 +93,13 @@ type customProvider struct{}
|
||||
|
||||
func (p *customProvider) Identifier() string { return "my-provider" }
|
||||
|
||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, error) {
|
||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, *sdkaccess.AuthError) {
|
||||
token := r.Header.Get("X-Custom")
|
||||
if token == "" {
|
||||
return nil, sdkaccess.ErrNoCredentials
|
||||
return nil, sdkaccess.NewNotHandledError()
|
||||
}
|
||||
if token != "expected" {
|
||||
return nil, sdkaccess.ErrInvalidCredential
|
||||
return nil, sdkaccess.NewInvalidCredentialError()
|
||||
}
|
||||
return &sdkaccess.Result{
|
||||
Provider: p.Identifier(),
|
||||
@@ -126,51 +109,46 @@ func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sd
|
||||
}
|
||||
|
||||
func init() {
|
||||
sdkaccess.RegisterProvider("custom", func(cfg *config.AccessProvider, root *config.Config) (sdkaccess.Provider, error) {
|
||||
return &customProvider{}, nil
|
||||
})
|
||||
sdkaccess.RegisterProvider("custom", &customProvider{})
|
||||
}
|
||||
```
|
||||
|
||||
自定义提供者需要实现 `Identifier()` 与 `Authenticate()`。在 `init` 中调用 `RegisterProvider` 暴露给配置层,工厂函数既能读取当前条目,也能访问完整根配置。
|
||||
自定义提供者需要实现 `Identifier()` 与 `Authenticate()`。在 `init` 中用已初始化实例调用 `RegisterProvider` 注册到全局 registry。
|
||||
|
||||
## 错误语义
|
||||
|
||||
- `ErrNoCredentials`:任何提供者都未识别到凭证。
|
||||
- `ErrInvalidCredential`:至少一个提供者处理了凭证但判定无效。
|
||||
- `ErrNotHandled`:告诉管理器跳到下一个提供者,不影响最终错误统计。
|
||||
- `NewNoCredentialsError()`(`AuthErrorCodeNoCredentials`):未提供或未识别到凭证。(HTTP 401)
|
||||
- `NewInvalidCredentialError()`(`AuthErrorCodeInvalidCredential`):凭证存在但校验失败。(HTTP 401)
|
||||
- `NewNotHandledError()`(`AuthErrorCodeNotHandled`):告诉管理器跳到下一个 provider。
|
||||
- `NewInternalAuthError(message, cause)`(`AuthErrorCodeInternal`):网络/系统错误。(HTTP 500)
|
||||
|
||||
自定义错误(例如网络异常)会马上冒泡返回。
|
||||
除可汇总的 `not_handled` / `no_credentials` / `invalid_credential` 外,其它错误会立即冒泡返回。
|
||||
|
||||
## 与 cliproxy 集成
|
||||
|
||||
使用 `sdk/cliproxy` 构建服务时会自动接入 `@sdk/access`。如果需要扩展内置行为,可传入自定义管理器:
|
||||
使用 `sdk/cliproxy` 构建服务时会自动接入 `@sdk/access`。如果希望在宿主进程里复用同一个 `Manager` 实例,可传入自定义管理器:
|
||||
|
||||
```go
|
||||
coreCfg, _ := config.LoadConfig("config.yaml")
|
||||
providers, _ := sdkaccess.BuildProviders(coreCfg)
|
||||
manager := sdkaccess.NewManager()
|
||||
manager.SetProviders(providers)
|
||||
accessManager := sdkaccess.NewManager()
|
||||
|
||||
svc, _ := cliproxy.NewBuilder().
|
||||
WithConfig(coreCfg).
|
||||
WithAccessManager(manager).
|
||||
WithConfigPath("config.yaml").
|
||||
WithRequestAccessManager(accessManager).
|
||||
Build()
|
||||
```
|
||||
|
||||
服务会复用该管理器处理每一个入站请求,实现与 CLI 二进制一致的访问控制体验。
|
||||
请在调用 `Build()` 之前完成自定义 provider 的注册(通常通过空白导入触发 `init`),以确保它们被包含在全局 registry 的快照中。
|
||||
|
||||
### 动态热更新提供者
|
||||
|
||||
当配置发生变化时,可以重新构建提供者并替换当前列表:
|
||||
当配置发生变化时,刷新依赖配置的 provider,然后重置 manager 的 provider 链:
|
||||
|
||||
```go
|
||||
providers, err := sdkaccess.BuildProviders(newCfg)
|
||||
if err != nil {
|
||||
log.Errorf("reload auth providers failed: %v", err)
|
||||
return
|
||||
}
|
||||
accessManager.SetProviders(providers)
|
||||
// configaccess is github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access
|
||||
configaccess.Register(&newCfg.SDKConfig)
|
||||
accessManager.SetProviders(sdkaccess.RegisteredProviders())
|
||||
```
|
||||
|
||||
这一流程与 `cliproxy.Service.refreshAccessProviders` 和 `api.Server.applyAccessConfig` 保持一致,避免为更新访问策略而重启进程。
|
||||
这一流程与 `internal/access.ApplyAccessProviders` 保持一致,避免为更新访问策略而重启进程。
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -23,13 +24,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
clipexec "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/logging"
|
||||
sdktr "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
)
|
||||
|
||||
@@ -122,7 +123,9 @@ func (MyExecutor) Execute(ctx context.Context, a *coreauth.Auth, req clipexec.Re
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Inject credentials via PrepareRequest hook.
|
||||
_ = (MyExecutor{}).PrepareRequest(httpReq, a)
|
||||
if errPrep := (MyExecutor{}).PrepareRequest(httpReq, a); errPrep != nil {
|
||||
return clipexec.Response{}, errPrep
|
||||
}
|
||||
|
||||
resp, errDo := client.Do(httpReq)
|
||||
if errDo != nil {
|
||||
@@ -130,24 +133,39 @@ func (MyExecutor) Execute(ctx context.Context, a *coreauth.Auth, req clipexec.Re
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
// Best-effort close; log if needed in real projects.
|
||||
fmt.Fprintf(os.Stderr, "close response body error: %v\n", errClose)
|
||||
}
|
||||
}()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return clipexec.Response{Payload: body}, nil
|
||||
}
|
||||
|
||||
func (MyExecutor) HttpRequest(ctx context.Context, a *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("myprov executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if errPrep := (MyExecutor{}).PrepareRequest(httpReq, a); errPrep != nil {
|
||||
return nil, errPrep
|
||||
}
|
||||
client := buildHTTPClient(a)
|
||||
return client.Do(httpReq)
|
||||
}
|
||||
|
||||
func (MyExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (clipexec.Response, error) {
|
||||
return clipexec.Response{}, errors.New("count tokens not implemented")
|
||||
}
|
||||
|
||||
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (<-chan clipexec.StreamChunk, error) {
|
||||
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (*clipexec.StreamResult, error) {
|
||||
ch := make(chan clipexec.StreamChunk, 1)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
ch <- clipexec.StreamChunk{Payload: []byte("data: {\"ok\":true}\n\n")}
|
||||
}()
|
||||
return ch, nil
|
||||
return &clipexec.StreamResult{Chunks: ch}, nil
|
||||
}
|
||||
|
||||
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
@@ -187,7 +205,7 @@ func main() {
|
||||
// Optional: add a simple middleware + custom request logger
|
||||
api.WithMiddleware(func(c *gin.Context) { c.Header("X-Example", "custom-provider"); c.Next() }),
|
||||
api.WithRequestLoggerFactory(func(cfg *config.Config, cfgPath string) logging.RequestLogger {
|
||||
return logging.NewFileRequestLogger(true, "logs", filepath.Dir(cfgPath))
|
||||
return logging.NewFileRequestLoggerWithOptions(true, "logs", filepath.Dir(cfgPath), cfg.ErrorLogsMaxFiles)
|
||||
}),
|
||||
).
|
||||
WithHooks(hooks).
|
||||
@@ -199,8 +217,8 @@ func main() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if err := svc.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
|
||||
panic(err)
|
||||
if errRun := svc.Run(ctx); errRun != nil && !errors.Is(errRun, context.Canceled) {
|
||||
panic(errRun)
|
||||
}
|
||||
_ = os.Stderr // keep os import used (demo only)
|
||||
_ = time.Second
|
||||
|
||||
140
examples/http-request/main.go
Normal file
140
examples/http-request/main.go
Normal file
@@ -0,0 +1,140 @@
|
||||
// Package main demonstrates how to use coreauth.Manager.HttpRequest/NewHttpRequest
|
||||
// to execute arbitrary HTTP requests with provider credentials injected.
|
||||
//
|
||||
// This example registers a minimal custom executor that injects an Authorization
|
||||
// header from auth.Attributes["api_key"], then performs two requests against
|
||||
// httpbin.org to show the injected headers.
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
clipexec "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const providerKey = "echo"
|
||||
|
||||
// EchoExecutor is a minimal provider implementation for demonstration purposes.
|
||||
type EchoExecutor struct{}
|
||||
|
||||
func (EchoExecutor) Identifier() string { return providerKey }
|
||||
|
||||
func (EchoExecutor) PrepareRequest(req *http.Request, auth *coreauth.Auth) error {
|
||||
if req == nil || auth == nil {
|
||||
return nil
|
||||
}
|
||||
if auth.Attributes != nil {
|
||||
if apiKey := strings.TrimSpace(auth.Attributes["api_key"]); apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (EchoExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("echo executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if errPrep := (EchoExecutor{}).PrepareRequest(httpReq, auth); errPrep != nil {
|
||||
return nil, errPrep
|
||||
}
|
||||
return http.DefaultClient.Do(httpReq)
|
||||
}
|
||||
|
||||
func (EchoExecutor) Execute(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (clipexec.Response, error) {
|
||||
return clipexec.Response{}, errors.New("echo executor: Execute not implemented")
|
||||
}
|
||||
|
||||
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (*clipexec.StreamResult, error) {
|
||||
return nil, errors.New("echo executor: ExecuteStream not implemented")
|
||||
}
|
||||
|
||||
func (EchoExecutor) Refresh(context.Context, *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return nil, errors.New("echo executor: Refresh not implemented")
|
||||
}
|
||||
|
||||
func (EchoExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (clipexec.Response, error) {
|
||||
return clipexec.Response{}, errors.New("echo executor: CountTokens not implemented")
|
||||
}
|
||||
|
||||
func main() {
|
||||
log.SetLevel(log.InfoLevel)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
core := coreauth.NewManager(nil, nil, nil)
|
||||
core.RegisterExecutor(EchoExecutor{})
|
||||
|
||||
auth := &coreauth.Auth{
|
||||
ID: "demo-echo",
|
||||
Provider: providerKey,
|
||||
Attributes: map[string]string{
|
||||
"api_key": "demo-api-key",
|
||||
},
|
||||
}
|
||||
|
||||
// Example 1: Build a prepared request and execute it using your own http.Client.
|
||||
reqPrepared, errReqPrepared := core.NewHttpRequest(
|
||||
ctx,
|
||||
auth,
|
||||
http.MethodGet,
|
||||
"https://httpbin.org/anything",
|
||||
nil,
|
||||
http.Header{"X-Example": []string{"prepared"}},
|
||||
)
|
||||
if errReqPrepared != nil {
|
||||
panic(errReqPrepared)
|
||||
}
|
||||
respPrepared, errDoPrepared := http.DefaultClient.Do(reqPrepared)
|
||||
if errDoPrepared != nil {
|
||||
panic(errDoPrepared)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := respPrepared.Body.Close(); errClose != nil {
|
||||
log.Errorf("close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
bodyPrepared, errReadPrepared := io.ReadAll(respPrepared.Body)
|
||||
if errReadPrepared != nil {
|
||||
panic(errReadPrepared)
|
||||
}
|
||||
fmt.Printf("Prepared request status: %d\n%s\n\n", respPrepared.StatusCode, bodyPrepared)
|
||||
|
||||
// Example 2: Execute a raw request via core.HttpRequest (auto inject + do).
|
||||
rawBody := []byte(`{"hello":"world"}`)
|
||||
rawReq, errRawReq := http.NewRequestWithContext(ctx, http.MethodPost, "https://httpbin.org/anything", bytes.NewReader(rawBody))
|
||||
if errRawReq != nil {
|
||||
panic(errRawReq)
|
||||
}
|
||||
rawReq.Header.Set("Content-Type", "application/json")
|
||||
rawReq.Header.Set("X-Example", "executed")
|
||||
|
||||
respExec, errDoExec := core.HttpRequest(ctx, auth, rawReq)
|
||||
if errDoExec != nil {
|
||||
panic(errDoExec)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := respExec.Body.Close(); errClose != nil {
|
||||
log.Errorf("close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
bodyExec, errReadExec := io.ReadAll(respExec.Body)
|
||||
if errReadExec != nil {
|
||||
panic(errReadExec)
|
||||
}
|
||||
fmt.Printf("Manager HttpRequest status: %d\n%s\n", respExec.StatusCode, bodyExec)
|
||||
}
|
||||
39
go.mod
39
go.mod
@@ -1,10 +1,15 @@
|
||||
module github.com/router-for-me/CLIProxyAPI/v6
|
||||
|
||||
go 1.24.0
|
||||
go 1.26.0
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.0.6
|
||||
github.com/atotto/clipboard v0.1.4
|
||||
github.com/charmbracelet/bubbles v1.0.0
|
||||
github.com/charmbracelet/bubbletea v1.3.10
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/fxamacker/cbor/v2 v2.9.0
|
||||
github.com/gin-gonic/gin v1.10.1
|
||||
github.com/go-git/go-git/v6 v6.0.0-20251009132922-75a182125145
|
||||
github.com/google/uuid v1.6.0
|
||||
@@ -13,14 +18,17 @@ require (
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/klauspost/compress v1.17.4
|
||||
github.com/minio/minio-go/v7 v7.0.66
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
|
||||
github.com/refraction-networking/utls v1.8.2
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/tiktoken-go/tokenizer v0.7.0
|
||||
golang.org/x/crypto v0.43.0
|
||||
golang.org/x/net v0.46.0
|
||||
golang.org/x/crypto v0.45.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/sync v0.18.0
|
||||
golang.org/x/term v0.37.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
@@ -29,8 +37,16 @@ require (
|
||||
cloud.google.com/go/compute/metadata v0.3.0 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/ProtonMail/go-crypto v1.3.0 // indirect
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.4.1 // indirect
|
||||
github.com/charmbracelet/x/ansi v0.11.6 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/term v0.2.2 // indirect
|
||||
github.com/clipperhouse/displaywidth v0.9.0 // indirect
|
||||
github.com/clipperhouse/stringish v0.1.1 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
|
||||
github.com/cloudflare/circl v1.6.1 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
@@ -38,6 +54,7 @@ require (
|
||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/emirpasic/gods v1.18.1 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-git/gcfg/v2 v2.0.2 // indirect
|
||||
@@ -54,23 +71,31 @@ require (
|
||||
github.com/kevinburke/ssh_config v1.4.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.19 // indirect
|
||||
github.com/minio/md5-simd v1.1.2 // indirect
|
||||
github.com/minio/sha256-simd v1.0.1 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pjbgf/sha1cd v0.5.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/rs/xid v1.5.0 // indirect
|
||||
github.com/sergi/go-diff v1.4.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/x448/float16 v0.8.4 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/sync v0.17.0 // indirect
|
||||
golang.org/x/sys v0.37.0 // indirect
|
||||
golang.org/x/text v0.30.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
)
|
||||
|
||||
80
go.sum
80
go.sum
@@ -10,10 +10,34 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
|
||||
github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
|
||||
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
|
||||
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
|
||||
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
|
||||
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
|
||||
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
|
||||
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
|
||||
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
|
||||
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
|
||||
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
|
||||
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
|
||||
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
|
||||
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
||||
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
|
||||
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
|
||||
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
||||
@@ -33,8 +57,12 @@ github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o
|
||||
github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE=
|
||||
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
||||
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM=
|
||||
github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||
@@ -99,8 +127,14 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
|
||||
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
|
||||
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
|
||||
github.com/minio/minio-go/v7 v7.0.66 h1:bnTOXOHjOqv/gcMuiVbN9o2ngRItvqE774dG9nq0Dzw=
|
||||
@@ -112,12 +146,24 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
|
||||
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
|
||||
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
||||
@@ -126,8 +172,6 @@ github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw=
|
||||
github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA=
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
@@ -157,25 +201,33 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
|
||||
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
|
||||
@@ -4,19 +4,28 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
var registerOnce sync.Once
|
||||
|
||||
// Register ensures the config-access provider is available to the access manager.
|
||||
func Register() {
|
||||
registerOnce.Do(func() {
|
||||
sdkaccess.RegisterProvider(sdkconfig.AccessProviderTypeConfigAPIKey, newProvider)
|
||||
})
|
||||
func Register(cfg *sdkconfig.SDKConfig) {
|
||||
if cfg == nil {
|
||||
sdkaccess.UnregisterProvider(sdkaccess.AccessProviderTypeConfigAPIKey)
|
||||
return
|
||||
}
|
||||
|
||||
keys := normalizeKeys(cfg.APIKeys)
|
||||
if len(keys) == 0 {
|
||||
sdkaccess.UnregisterProvider(sdkaccess.AccessProviderTypeConfigAPIKey)
|
||||
return
|
||||
}
|
||||
|
||||
sdkaccess.RegisterProvider(
|
||||
sdkaccess.AccessProviderTypeConfigAPIKey,
|
||||
newProvider(sdkaccess.DefaultAccessProviderName, keys),
|
||||
)
|
||||
}
|
||||
|
||||
type provider struct {
|
||||
@@ -24,34 +33,31 @@ type provider struct {
|
||||
keys map[string]struct{}
|
||||
}
|
||||
|
||||
func newProvider(cfg *sdkconfig.AccessProvider, _ *sdkconfig.SDKConfig) (sdkaccess.Provider, error) {
|
||||
name := cfg.Name
|
||||
if name == "" {
|
||||
name = sdkconfig.DefaultAccessProviderName
|
||||
func newProvider(name string, keys []string) *provider {
|
||||
providerName := strings.TrimSpace(name)
|
||||
if providerName == "" {
|
||||
providerName = sdkaccess.DefaultAccessProviderName
|
||||
}
|
||||
keys := make(map[string]struct{}, len(cfg.APIKeys))
|
||||
for _, key := range cfg.APIKeys {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
keys[key] = struct{}{}
|
||||
keySet := make(map[string]struct{}, len(keys))
|
||||
for _, key := range keys {
|
||||
keySet[key] = struct{}{}
|
||||
}
|
||||
return &provider{name: name, keys: keys}, nil
|
||||
return &provider{name: providerName, keys: keySet}
|
||||
}
|
||||
|
||||
func (p *provider) Identifier() string {
|
||||
if p == nil || p.name == "" {
|
||||
return sdkconfig.DefaultAccessProviderName
|
||||
return sdkaccess.DefaultAccessProviderName
|
||||
}
|
||||
return p.name
|
||||
}
|
||||
|
||||
func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.Result, error) {
|
||||
func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.Result, *sdkaccess.AuthError) {
|
||||
if p == nil {
|
||||
return nil, sdkaccess.ErrNotHandled
|
||||
return nil, sdkaccess.NewNotHandledError()
|
||||
}
|
||||
if len(p.keys) == 0 {
|
||||
return nil, sdkaccess.ErrNotHandled
|
||||
return nil, sdkaccess.NewNotHandledError()
|
||||
}
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
authHeaderGoogle := r.Header.Get("X-Goog-Api-Key")
|
||||
@@ -63,7 +69,7 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
|
||||
queryAuthToken = r.URL.Query().Get("auth_token")
|
||||
}
|
||||
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" && queryAuthToken == "" {
|
||||
return nil, sdkaccess.ErrNoCredentials
|
||||
return nil, sdkaccess.NewNoCredentialsError()
|
||||
}
|
||||
|
||||
apiKey := extractBearerToken(authHeader)
|
||||
@@ -94,7 +100,7 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
|
||||
}
|
||||
}
|
||||
|
||||
return nil, sdkaccess.ErrInvalidCredential
|
||||
return nil, sdkaccess.NewInvalidCredentialError()
|
||||
}
|
||||
|
||||
func extractBearerToken(header string) string {
|
||||
@@ -110,3 +116,26 @@ func extractBearerToken(header string) string {
|
||||
}
|
||||
return strings.TrimSpace(parts[1])
|
||||
}
|
||||
|
||||
func normalizeKeys(keys []string) []string {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
normalized := make([]string, 0, len(keys))
|
||||
seen := make(map[string]struct{}, len(keys))
|
||||
for _, key := range keys {
|
||||
trimmedKey := strings.TrimSpace(key)
|
||||
if trimmedKey == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[trimmedKey]; exists {
|
||||
continue
|
||||
}
|
||||
seen[trimmedKey] = struct{}{}
|
||||
normalized = append(normalized, trimmedKey)
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return nil
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
@@ -6,9 +6,9 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
sdkConfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -17,26 +17,26 @@ import (
|
||||
// ordered provider slice along with the identifiers of providers that were added, updated, or
|
||||
// removed compared to the previous configuration.
|
||||
func ReconcileProviders(oldCfg, newCfg *config.Config, existing []sdkaccess.Provider) (result []sdkaccess.Provider, added, updated, removed []string, err error) {
|
||||
_ = oldCfg
|
||||
if newCfg == nil {
|
||||
return nil, nil, nil, nil, nil
|
||||
}
|
||||
|
||||
result = sdkaccess.RegisteredProviders()
|
||||
|
||||
existingMap := make(map[string]sdkaccess.Provider, len(existing))
|
||||
for _, provider := range existing {
|
||||
if provider == nil {
|
||||
providerID := identifierFromProvider(provider)
|
||||
if providerID == "" {
|
||||
continue
|
||||
}
|
||||
existingMap[provider.Identifier()] = provider
|
||||
existingMap[providerID] = provider
|
||||
}
|
||||
|
||||
oldCfgMap := accessProviderMap(oldCfg)
|
||||
newEntries := collectProviderEntries(newCfg)
|
||||
|
||||
result = make([]sdkaccess.Provider, 0, len(newEntries))
|
||||
finalIDs := make(map[string]struct{}, len(newEntries))
|
||||
finalIDs := make(map[string]struct{}, len(result))
|
||||
|
||||
isInlineProvider := func(id string) bool {
|
||||
return strings.EqualFold(id, sdkConfig.DefaultAccessProviderName)
|
||||
return strings.EqualFold(id, sdkaccess.DefaultAccessProviderName)
|
||||
}
|
||||
appendChange := func(list *[]string, id string) {
|
||||
if isInlineProvider(id) {
|
||||
@@ -45,85 +45,28 @@ func ReconcileProviders(oldCfg, newCfg *config.Config, existing []sdkaccess.Prov
|
||||
*list = append(*list, id)
|
||||
}
|
||||
|
||||
for _, providerCfg := range newEntries {
|
||||
key := providerIdentifier(providerCfg)
|
||||
if key == "" {
|
||||
for _, provider := range result {
|
||||
providerID := identifierFromProvider(provider)
|
||||
if providerID == "" {
|
||||
continue
|
||||
}
|
||||
finalIDs[providerID] = struct{}{}
|
||||
|
||||
forceRebuild := strings.EqualFold(strings.TrimSpace(providerCfg.Type), sdkConfig.AccessProviderTypeConfigAPIKey)
|
||||
if oldCfgProvider, ok := oldCfgMap[key]; ok {
|
||||
isAliased := oldCfgProvider == providerCfg
|
||||
if !forceRebuild && !isAliased && providerConfigEqual(oldCfgProvider, providerCfg) {
|
||||
if existingProvider, okExisting := existingMap[key]; okExisting {
|
||||
result = append(result, existingProvider)
|
||||
finalIDs[key] = struct{}{}
|
||||
continue
|
||||
}
|
||||
}
|
||||
existingProvider, exists := existingMap[providerID]
|
||||
if !exists {
|
||||
appendChange(&added, providerID)
|
||||
continue
|
||||
}
|
||||
|
||||
provider, buildErr := sdkaccess.BuildProvider(providerCfg, &newCfg.SDKConfig)
|
||||
if buildErr != nil {
|
||||
return nil, nil, nil, nil, buildErr
|
||||
}
|
||||
if _, ok := oldCfgMap[key]; ok {
|
||||
if _, existed := existingMap[key]; existed {
|
||||
appendChange(&updated, key)
|
||||
} else {
|
||||
appendChange(&added, key)
|
||||
}
|
||||
} else {
|
||||
appendChange(&added, key)
|
||||
}
|
||||
result = append(result, provider)
|
||||
finalIDs[key] = struct{}{}
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
if inline := sdkConfig.MakeInlineAPIKeyProvider(newCfg.APIKeys); inline != nil {
|
||||
key := providerIdentifier(inline)
|
||||
if key != "" {
|
||||
if oldCfgProvider, ok := oldCfgMap[key]; ok {
|
||||
if providerConfigEqual(oldCfgProvider, inline) {
|
||||
if existingProvider, okExisting := existingMap[key]; okExisting {
|
||||
result = append(result, existingProvider)
|
||||
finalIDs[key] = struct{}{}
|
||||
goto inlineDone
|
||||
}
|
||||
}
|
||||
}
|
||||
provider, buildErr := sdkaccess.BuildProvider(inline, &newCfg.SDKConfig)
|
||||
if buildErr != nil {
|
||||
return nil, nil, nil, nil, buildErr
|
||||
}
|
||||
if _, existed := existingMap[key]; existed {
|
||||
appendChange(&updated, key)
|
||||
} else if _, hadOld := oldCfgMap[key]; hadOld {
|
||||
appendChange(&updated, key)
|
||||
} else {
|
||||
appendChange(&added, key)
|
||||
}
|
||||
result = append(result, provider)
|
||||
finalIDs[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
inlineDone:
|
||||
}
|
||||
|
||||
removedSet := make(map[string]struct{})
|
||||
for id := range existingMap {
|
||||
if _, ok := finalIDs[id]; !ok {
|
||||
if isInlineProvider(id) {
|
||||
continue
|
||||
}
|
||||
removedSet[id] = struct{}{}
|
||||
if !providerInstanceEqual(existingProvider, provider) {
|
||||
appendChange(&updated, providerID)
|
||||
}
|
||||
}
|
||||
|
||||
removed = make([]string, 0, len(removedSet))
|
||||
for id := range removedSet {
|
||||
removed = append(removed, id)
|
||||
for providerID := range existingMap {
|
||||
if _, exists := finalIDs[providerID]; exists {
|
||||
continue
|
||||
}
|
||||
appendChange(&removed, providerID)
|
||||
}
|
||||
|
||||
sort.Strings(added)
|
||||
@@ -142,6 +85,7 @@ func ApplyAccessProviders(manager *sdkaccess.Manager, oldCfg, newCfg *config.Con
|
||||
}
|
||||
|
||||
existing := manager.Providers()
|
||||
configaccess.Register(&newCfg.SDKConfig)
|
||||
providers, added, updated, removed, err := ReconcileProviders(oldCfg, newCfg, existing)
|
||||
if err != nil {
|
||||
log.Errorf("failed to reconcile request auth providers: %v", err)
|
||||
@@ -160,111 +104,24 @@ func ApplyAccessProviders(manager *sdkaccess.Manager, oldCfg, newCfg *config.Con
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func accessProviderMap(cfg *config.Config) map[string]*sdkConfig.AccessProvider {
|
||||
result := make(map[string]*sdkConfig.AccessProvider)
|
||||
if cfg == nil {
|
||||
return result
|
||||
}
|
||||
for i := range cfg.Access.Providers {
|
||||
providerCfg := &cfg.Access.Providers[i]
|
||||
if providerCfg.Type == "" {
|
||||
continue
|
||||
}
|
||||
key := providerIdentifier(providerCfg)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
result[key] = providerCfg
|
||||
}
|
||||
if len(result) == 0 && len(cfg.APIKeys) > 0 {
|
||||
if provider := sdkConfig.MakeInlineAPIKeyProvider(cfg.APIKeys); provider != nil {
|
||||
if key := providerIdentifier(provider); key != "" {
|
||||
result[key] = provider
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func collectProviderEntries(cfg *config.Config) []*sdkConfig.AccessProvider {
|
||||
entries := make([]*sdkConfig.AccessProvider, 0, len(cfg.Access.Providers))
|
||||
for i := range cfg.Access.Providers {
|
||||
providerCfg := &cfg.Access.Providers[i]
|
||||
if providerCfg.Type == "" {
|
||||
continue
|
||||
}
|
||||
if key := providerIdentifier(providerCfg); key != "" {
|
||||
entries = append(entries, providerCfg)
|
||||
}
|
||||
}
|
||||
if len(entries) == 0 && len(cfg.APIKeys) > 0 {
|
||||
if inline := sdkConfig.MakeInlineAPIKeyProvider(cfg.APIKeys); inline != nil {
|
||||
entries = append(entries, inline)
|
||||
}
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
func providerIdentifier(provider *sdkConfig.AccessProvider) string {
|
||||
func identifierFromProvider(provider sdkaccess.Provider) string {
|
||||
if provider == nil {
|
||||
return ""
|
||||
}
|
||||
if name := strings.TrimSpace(provider.Name); name != "" {
|
||||
return name
|
||||
}
|
||||
typ := strings.TrimSpace(provider.Type)
|
||||
if typ == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.EqualFold(typ, sdkConfig.AccessProviderTypeConfigAPIKey) {
|
||||
return sdkConfig.DefaultAccessProviderName
|
||||
}
|
||||
return typ
|
||||
return strings.TrimSpace(provider.Identifier())
|
||||
}
|
||||
|
||||
func providerConfigEqual(a, b *sdkConfig.AccessProvider) bool {
|
||||
func providerInstanceEqual(a, b sdkaccess.Provider) bool {
|
||||
if a == nil || b == nil {
|
||||
return a == nil && b == nil
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(a.Type), strings.TrimSpace(b.Type)) {
|
||||
if reflect.TypeOf(a) != reflect.TypeOf(b) {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(a.SDK) != strings.TrimSpace(b.SDK) {
|
||||
return false
|
||||
valueA := reflect.ValueOf(a)
|
||||
valueB := reflect.ValueOf(b)
|
||||
if valueA.Kind() == reflect.Pointer && valueB.Kind() == reflect.Pointer {
|
||||
return valueA.Pointer() == valueB.Pointer()
|
||||
}
|
||||
if !stringSetEqual(a.APIKeys, b.APIKeys) {
|
||||
return false
|
||||
}
|
||||
if len(a.Config) != len(b.Config) {
|
||||
return false
|
||||
}
|
||||
if len(a.Config) > 0 && !reflect.DeepEqual(a.Config, b.Config) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func stringSetEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
if len(a) == 0 {
|
||||
return true
|
||||
}
|
||||
seen := make(map[string]int, len(a))
|
||||
for _, val := range a {
|
||||
seen[val]++
|
||||
}
|
||||
for _, val := range b {
|
||||
count := seen[val]
|
||||
if count == 0 {
|
||||
return false
|
||||
}
|
||||
if count == 1 {
|
||||
delete(seen, val)
|
||||
} else {
|
||||
seen[val] = count - 1
|
||||
}
|
||||
}
|
||||
return len(seen) == 0
|
||||
return reflect.DeepEqual(a, b)
|
||||
}
|
||||
|
||||
1187
internal/api/handlers/management/api_tools.go
Normal file
1187
internal/api/handlers/management/api_tools.go
Normal file
File diff suppressed because it is too large
Load Diff
149
internal/api/handlers/management/api_tools_cbor_test.go
Normal file
149
internal/api/handlers/management/api_tools_cbor_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestAPICall_CBOR_Support(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Create a test handler
|
||||
h := &Handler{}
|
||||
|
||||
// Create test request data
|
||||
reqData := apiCallRequest{
|
||||
Method: "GET",
|
||||
URL: "https://httpbin.org/get",
|
||||
Header: map[string]string{
|
||||
"User-Agent": "test-client",
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("JSON request and response", func(t *testing.T) {
|
||||
// Marshal request as JSON
|
||||
jsonData, err := json.Marshal(reqData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal JSON: %v", err)
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req := httptest.NewRequest(http.MethodPost, "/v0/management/api-call", bytes.NewReader(jsonData))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Create response recorder
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Create Gin context
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// Call handler
|
||||
h.APICall(c)
|
||||
|
||||
// Verify response
|
||||
if w.Code != http.StatusOK && w.Code != http.StatusBadGateway {
|
||||
t.Logf("Response status: %d", w.Code)
|
||||
t.Logf("Response body: %s", w.Body.String())
|
||||
}
|
||||
|
||||
// Check content type
|
||||
contentType := w.Header().Get("Content-Type")
|
||||
if w.Code == http.StatusOK && !contains(contentType, "application/json") {
|
||||
t.Errorf("Expected JSON response, got: %s", contentType)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CBOR request and response", func(t *testing.T) {
|
||||
// Marshal request as CBOR
|
||||
cborData, err := cbor.Marshal(reqData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal CBOR: %v", err)
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req := httptest.NewRequest(http.MethodPost, "/v0/management/api-call", bytes.NewReader(cborData))
|
||||
req.Header.Set("Content-Type", "application/cbor")
|
||||
|
||||
// Create response recorder
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Create Gin context
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
// Call handler
|
||||
h.APICall(c)
|
||||
|
||||
// Verify response
|
||||
if w.Code != http.StatusOK && w.Code != http.StatusBadGateway {
|
||||
t.Logf("Response status: %d", w.Code)
|
||||
t.Logf("Response body: %s", w.Body.String())
|
||||
}
|
||||
|
||||
// Check content type
|
||||
contentType := w.Header().Get("Content-Type")
|
||||
if w.Code == http.StatusOK && !contains(contentType, "application/cbor") {
|
||||
t.Errorf("Expected CBOR response, got: %s", contentType)
|
||||
}
|
||||
|
||||
// Try to decode CBOR response
|
||||
if w.Code == http.StatusOK {
|
||||
var response apiCallResponse
|
||||
if err := cbor.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Errorf("Failed to unmarshal CBOR response: %v", err)
|
||||
} else {
|
||||
t.Logf("CBOR response decoded successfully: status_code=%d", response.StatusCode)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CBOR encoding and decoding consistency", func(t *testing.T) {
|
||||
// Test data
|
||||
testReq := apiCallRequest{
|
||||
Method: "POST",
|
||||
URL: "https://example.com/api",
|
||||
Header: map[string]string{
|
||||
"Authorization": "Bearer $TOKEN$",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
Data: `{"key":"value"}`,
|
||||
}
|
||||
|
||||
// Encode to CBOR
|
||||
cborData, err := cbor.Marshal(testReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal to CBOR: %v", err)
|
||||
}
|
||||
|
||||
// Decode from CBOR
|
||||
var decoded apiCallRequest
|
||||
if err := cbor.Unmarshal(cborData, &decoded); err != nil {
|
||||
t.Fatalf("Failed to unmarshal from CBOR: %v", err)
|
||||
}
|
||||
|
||||
// Verify fields
|
||||
if decoded.Method != testReq.Method {
|
||||
t.Errorf("Method mismatch: got %s, want %s", decoded.Method, testReq.Method)
|
||||
}
|
||||
if decoded.URL != testReq.URL {
|
||||
t.Errorf("URL mismatch: got %s, want %s", decoded.URL, testReq.URL)
|
||||
}
|
||||
if decoded.Data != testReq.Data {
|
||||
t.Errorf("Data mismatch: got %s, want %s", decoded.Data, testReq.Data)
|
||||
}
|
||||
if len(decoded.Header) != len(testReq.Header) {
|
||||
t.Errorf("Header count mismatch: got %d, want %d", len(decoded.Header), len(testReq.Header))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) > 0 && len(substr) > 0 && (s == substr || len(s) >= len(substr) && s[:len(substr)] == substr || bytes.Contains([]byte(s), []byte(substr)))
|
||||
}
|
||||
173
internal/api/handlers/management/api_tools_test.go
Normal file
173
internal/api/handlers/management/api_tools_test.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
type memoryAuthStore struct {
|
||||
mu sync.Mutex
|
||||
items map[string]*coreauth.Auth
|
||||
}
|
||||
|
||||
func (s *memoryAuthStore) List(ctx context.Context) ([]*coreauth.Auth, error) {
|
||||
_ = ctx
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := make([]*coreauth.Auth, 0, len(s.items))
|
||||
for _, a := range s.items {
|
||||
out = append(out, a.Clone())
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *memoryAuthStore) Save(ctx context.Context, auth *coreauth.Auth) (string, error) {
|
||||
_ = ctx
|
||||
if auth == nil {
|
||||
return "", nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
if s.items == nil {
|
||||
s.items = make(map[string]*coreauth.Auth)
|
||||
}
|
||||
s.items[auth.ID] = auth.Clone()
|
||||
s.mu.Unlock()
|
||||
return auth.ID, nil
|
||||
}
|
||||
|
||||
func (s *memoryAuthStore) Delete(ctx context.Context, id string) error {
|
||||
_ = ctx
|
||||
s.mu.Lock()
|
||||
delete(s.items, id)
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestResolveTokenForAuth_Antigravity_RefreshesExpiredToken(t *testing.T) {
|
||||
var callCount int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") {
|
||||
t.Fatalf("unexpected content-type: %s", ct)
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
_ = r.Body.Close()
|
||||
values, err := url.ParseQuery(string(bodyBytes))
|
||||
if err != nil {
|
||||
t.Fatalf("parse form: %v", err)
|
||||
}
|
||||
if values.Get("grant_type") != "refresh_token" {
|
||||
t.Fatalf("unexpected grant_type: %s", values.Get("grant_type"))
|
||||
}
|
||||
if values.Get("refresh_token") != "rt" {
|
||||
t.Fatalf("unexpected refresh_token: %s", values.Get("refresh_token"))
|
||||
}
|
||||
if values.Get("client_id") != antigravityOAuthClientID {
|
||||
t.Fatalf("unexpected client_id: %s", values.Get("client_id"))
|
||||
}
|
||||
if values.Get("client_secret") != antigravityOAuthClientSecret {
|
||||
t.Fatalf("unexpected client_secret")
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"access_token": "new-token",
|
||||
"refresh_token": "rt2",
|
||||
"expires_in": int64(3600),
|
||||
"token_type": "Bearer",
|
||||
})
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
originalURL := antigravityOAuthTokenURL
|
||||
antigravityOAuthTokenURL = srv.URL
|
||||
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
|
||||
|
||||
store := &memoryAuthStore{}
|
||||
manager := coreauth.NewManager(store, nil, nil)
|
||||
|
||||
auth := &coreauth.Auth{
|
||||
ID: "antigravity-test.json",
|
||||
FileName: "antigravity-test.json",
|
||||
Provider: "antigravity",
|
||||
Metadata: map[string]any{
|
||||
"type": "antigravity",
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "rt",
|
||||
"expires_in": int64(3600),
|
||||
"timestamp": time.Now().Add(-2 * time.Hour).UnixMilli(),
|
||||
"expired": time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||
t.Fatalf("register auth: %v", err)
|
||||
}
|
||||
|
||||
h := &Handler{authManager: manager}
|
||||
token, err := h.resolveTokenForAuth(context.Background(), auth)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveTokenForAuth: %v", err)
|
||||
}
|
||||
if token != "new-token" {
|
||||
t.Fatalf("expected refreshed token, got %q", token)
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Fatalf("expected 1 refresh call, got %d", callCount)
|
||||
}
|
||||
|
||||
updated, ok := manager.GetByID(auth.ID)
|
||||
if !ok || updated == nil {
|
||||
t.Fatalf("expected auth in manager after update")
|
||||
}
|
||||
if got := tokenValueFromMetadata(updated.Metadata); got != "new-token" {
|
||||
t.Fatalf("expected manager metadata updated, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveTokenForAuth_Antigravity_SkipsRefreshWhenTokenValid(t *testing.T) {
|
||||
var callCount int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
originalURL := antigravityOAuthTokenURL
|
||||
antigravityOAuthTokenURL = srv.URL
|
||||
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
|
||||
|
||||
auth := &coreauth.Auth{
|
||||
ID: "antigravity-valid.json",
|
||||
FileName: "antigravity-valid.json",
|
||||
Provider: "antigravity",
|
||||
Metadata: map[string]any{
|
||||
"type": "antigravity",
|
||||
"access_token": "ok-token",
|
||||
"expired": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
h := &Handler{}
|
||||
token, err := h.resolveTokenForAuth(context.Background(), auth)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveTokenForAuth: %v", err)
|
||||
}
|
||||
if token != "ok-token" {
|
||||
t.Fatalf("expected existing token, got %q", token)
|
||||
}
|
||||
if callCount != 0 {
|
||||
t.Fatalf("expected no refresh calls, got %d", callCount)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
129
internal/api/handlers/management/auth_files_delete_test.go
Normal file
129
internal/api/handlers/management/auth_files_delete_test.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestDeleteAuthFile_UsesAuthPathFromManager(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
authDir := filepath.Join(tempDir, "auth")
|
||||
externalDir := filepath.Join(tempDir, "external")
|
||||
if errMkdirAuth := os.MkdirAll(authDir, 0o700); errMkdirAuth != nil {
|
||||
t.Fatalf("failed to create auth dir: %v", errMkdirAuth)
|
||||
}
|
||||
if errMkdirExternal := os.MkdirAll(externalDir, 0o700); errMkdirExternal != nil {
|
||||
t.Fatalf("failed to create external dir: %v", errMkdirExternal)
|
||||
}
|
||||
|
||||
fileName := "codex-user@example.com-plus.json"
|
||||
shadowPath := filepath.Join(authDir, fileName)
|
||||
realPath := filepath.Join(externalDir, fileName)
|
||||
if errWriteShadow := os.WriteFile(shadowPath, []byte(`{"type":"codex","email":"shadow@example.com"}`), 0o600); errWriteShadow != nil {
|
||||
t.Fatalf("failed to write shadow file: %v", errWriteShadow)
|
||||
}
|
||||
if errWriteReal := os.WriteFile(realPath, []byte(`{"type":"codex","email":"real@example.com"}`), 0o600); errWriteReal != nil {
|
||||
t.Fatalf("failed to write real file: %v", errWriteReal)
|
||||
}
|
||||
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
record := &coreauth.Auth{
|
||||
ID: "legacy/" + fileName,
|
||||
FileName: fileName,
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusError,
|
||||
Unavailable: true,
|
||||
Attributes: map[string]string{
|
||||
"path": realPath,
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"type": "codex",
|
||||
"email": "real@example.com",
|
||||
},
|
||||
}
|
||||
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
|
||||
t.Fatalf("failed to register auth record: %v", errRegister)
|
||||
}
|
||||
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||
h.tokenStore = &memoryAuthStore{}
|
||||
|
||||
deleteRec := httptest.NewRecorder()
|
||||
deleteCtx, _ := gin.CreateTestContext(deleteRec)
|
||||
deleteReq := httptest.NewRequest(http.MethodDelete, "/v0/management/auth-files?name="+url.QueryEscape(fileName), nil)
|
||||
deleteCtx.Request = deleteReq
|
||||
h.DeleteAuthFile(deleteCtx)
|
||||
|
||||
if deleteRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, deleteRec.Code, deleteRec.Body.String())
|
||||
}
|
||||
if _, errStatReal := os.Stat(realPath); !os.IsNotExist(errStatReal) {
|
||||
t.Fatalf("expected managed auth file to be removed, stat err: %v", errStatReal)
|
||||
}
|
||||
if _, errStatShadow := os.Stat(shadowPath); errStatShadow != nil {
|
||||
t.Fatalf("expected shadow auth file to remain, stat err: %v", errStatShadow)
|
||||
}
|
||||
|
||||
listRec := httptest.NewRecorder()
|
||||
listCtx, _ := gin.CreateTestContext(listRec)
|
||||
listReq := httptest.NewRequest(http.MethodGet, "/v0/management/auth-files", nil)
|
||||
listCtx.Request = listReq
|
||||
h.ListAuthFiles(listCtx)
|
||||
|
||||
if listRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected list status %d, got %d with body %s", http.StatusOK, listRec.Code, listRec.Body.String())
|
||||
}
|
||||
var listPayload map[string]any
|
||||
if errUnmarshal := json.Unmarshal(listRec.Body.Bytes(), &listPayload); errUnmarshal != nil {
|
||||
t.Fatalf("failed to decode list payload: %v", errUnmarshal)
|
||||
}
|
||||
filesRaw, ok := listPayload["files"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected files array, payload: %#v", listPayload)
|
||||
}
|
||||
if len(filesRaw) != 0 {
|
||||
t.Fatalf("expected removed auth to be hidden from list, got %d entries", len(filesRaw))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAuthFile_FallbackToAuthDirPath(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
authDir := t.TempDir()
|
||||
fileName := "fallback-user.json"
|
||||
filePath := filepath.Join(authDir, fileName)
|
||||
if errWrite := os.WriteFile(filePath, []byte(`{"type":"codex"}`), 0o600); errWrite != nil {
|
||||
t.Fatalf("failed to write auth file: %v", errWrite)
|
||||
}
|
||||
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
|
||||
h.tokenStore = &memoryAuthStore{}
|
||||
|
||||
deleteRec := httptest.NewRecorder()
|
||||
deleteCtx, _ := gin.CreateTestContext(deleteRec)
|
||||
deleteReq := httptest.NewRequest(http.MethodDelete, "/v0/management/auth-files?name="+url.QueryEscape(fileName), nil)
|
||||
deleteCtx.Request = deleteReq
|
||||
h.DeleteAuthFile(deleteCtx)
|
||||
|
||||
if deleteRec.Code != http.StatusOK {
|
||||
t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, deleteRec.Code, deleteRec.Body.String())
|
||||
}
|
||||
if _, errStat := os.Stat(filePath); !os.IsNotExist(errStat) {
|
||||
t.Fatalf("expected auth file to be removed from auth dir, stat err: %v", errStat)
|
||||
}
|
||||
}
|
||||
@@ -1,23 +1,94 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const (
|
||||
latestReleaseURL = "https://api.github.com/repos/router-for-me/CLIProxyAPIPlus/releases/latest"
|
||||
latestReleaseUserAgent = "CLIProxyAPIPlus"
|
||||
)
|
||||
|
||||
func (h *Handler) GetConfig(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
c.JSON(200, gin.H{})
|
||||
return
|
||||
}
|
||||
cfgCopy := *h.cfg
|
||||
c.JSON(200, &cfgCopy)
|
||||
c.JSON(200, new(*h.cfg))
|
||||
}
|
||||
|
||||
type releaseInfo struct {
|
||||
TagName string `json:"tag_name"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// GetLatestVersion returns the latest release version from GitHub without downloading assets.
|
||||
func (h *Handler) GetLatestVersion(c *gin.Context) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
proxyURL := ""
|
||||
if h != nil && h.cfg != nil {
|
||||
proxyURL = strings.TrimSpace(h.cfg.ProxyURL)
|
||||
}
|
||||
if proxyURL != "" {
|
||||
sdkCfg := &sdkconfig.SDKConfig{ProxyURL: proxyURL}
|
||||
util.SetProxy(sdkCfg, client)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, latestReleaseURL, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "request_create_failed", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github+json")
|
||||
req.Header.Set("User-Agent", latestReleaseUserAgent)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "request_failed", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.WithError(errClose).Debug("failed to close latest version response body")
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "unexpected_status", "message": fmt.Sprintf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))})
|
||||
return
|
||||
}
|
||||
|
||||
var info releaseInfo
|
||||
if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "decode_failed", "message": errDecode.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
version := strings.TrimSpace(info.TagName)
|
||||
if version == "" {
|
||||
version = strings.TrimSpace(info.Name)
|
||||
}
|
||||
if version == "" {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "invalid_response", "message": "missing release version"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"latest-version": version})
|
||||
}
|
||||
|
||||
func WriteConfig(path string, data []byte) error {
|
||||
@@ -130,6 +201,46 @@ func (h *Handler) PutLoggingToFile(c *gin.Context) {
|
||||
h.updateBoolField(c, func(v bool) { h.cfg.LoggingToFile = v })
|
||||
}
|
||||
|
||||
// LogsMaxTotalSizeMB
|
||||
func (h *Handler) GetLogsMaxTotalSizeMB(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"logs-max-total-size-mb": h.cfg.LogsMaxTotalSizeMB})
|
||||
}
|
||||
func (h *Handler) PutLogsMaxTotalSizeMB(c *gin.Context) {
|
||||
var body struct {
|
||||
Value *int `json:"value"`
|
||||
}
|
||||
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
value := *body.Value
|
||||
if value < 0 {
|
||||
value = 0
|
||||
}
|
||||
h.cfg.LogsMaxTotalSizeMB = value
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// ErrorLogsMaxFiles
|
||||
func (h *Handler) GetErrorLogsMaxFiles(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"error-logs-max-files": h.cfg.ErrorLogsMaxFiles})
|
||||
}
|
||||
func (h *Handler) PutErrorLogsMaxFiles(c *gin.Context) {
|
||||
var body struct {
|
||||
Value *int `json:"value"`
|
||||
}
|
||||
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
value := *body.Value
|
||||
if value < 0 {
|
||||
value = 10
|
||||
}
|
||||
h.cfg.ErrorLogsMaxFiles = value
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// Request log
|
||||
func (h *Handler) GetRequestLog(c *gin.Context) { c.JSON(200, gin.H{"request-log": h.cfg.RequestLog}) }
|
||||
func (h *Handler) PutRequestLog(c *gin.Context) {
|
||||
@@ -160,6 +271,52 @@ func (h *Handler) PutMaxRetryInterval(c *gin.Context) {
|
||||
h.updateIntField(c, func(v int) { h.cfg.MaxRetryInterval = v })
|
||||
}
|
||||
|
||||
// ForceModelPrefix
|
||||
func (h *Handler) GetForceModelPrefix(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"force-model-prefix": h.cfg.ForceModelPrefix})
|
||||
}
|
||||
func (h *Handler) PutForceModelPrefix(c *gin.Context) {
|
||||
h.updateBoolField(c, func(v bool) { h.cfg.ForceModelPrefix = v })
|
||||
}
|
||||
|
||||
func normalizeRoutingStrategy(strategy string) (string, bool) {
|
||||
normalized := strings.ToLower(strings.TrimSpace(strategy))
|
||||
switch normalized {
|
||||
case "", "round-robin", "roundrobin", "rr":
|
||||
return "round-robin", true
|
||||
case "fill-first", "fillfirst", "ff":
|
||||
return "fill-first", true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
// RoutingStrategy
|
||||
func (h *Handler) GetRoutingStrategy(c *gin.Context) {
|
||||
strategy, ok := normalizeRoutingStrategy(h.cfg.Routing.Strategy)
|
||||
if !ok {
|
||||
c.JSON(200, gin.H{"strategy": strings.TrimSpace(h.cfg.Routing.Strategy)})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"strategy": strategy})
|
||||
}
|
||||
func (h *Handler) PutRoutingStrategy(c *gin.Context) {
|
||||
var body struct {
|
||||
Value *string `json:"value"`
|
||||
}
|
||||
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
normalized, ok := normalizeRoutingStrategy(*body.Value)
|
||||
if !ok {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid strategy"})
|
||||
return
|
||||
}
|
||||
h.cfg.Routing.Strategy = normalized
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// Proxy URL
|
||||
func (h *Handler) GetProxyURL(c *gin.Context) { c.JSON(200, gin.H{"proxy-url": h.cfg.ProxyURL}) }
|
||||
func (h *Handler) PutProxyURL(c *gin.Context) {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -24,8 +24,15 @@ import (
|
||||
type attemptInfo struct {
|
||||
count int
|
||||
blockedUntil time.Time
|
||||
lastActivity time.Time // track last activity for cleanup
|
||||
}
|
||||
|
||||
// attemptCleanupInterval controls how often stale IP entries are purged
|
||||
const attemptCleanupInterval = 1 * time.Hour
|
||||
|
||||
// attemptMaxIdleTime controls how long an IP can be idle before cleanup
|
||||
const attemptMaxIdleTime = 2 * time.Hour
|
||||
|
||||
// Handler aggregates config reference, persistence path and helpers.
|
||||
type Handler struct {
|
||||
cfg *config.Config
|
||||
@@ -40,6 +47,7 @@ type Handler struct {
|
||||
allowRemoteOverride bool
|
||||
envSecret string
|
||||
logDir string
|
||||
postAuthHook coreauth.PostAuthHook
|
||||
}
|
||||
|
||||
// NewHandler creates a new management handler instance.
|
||||
@@ -47,7 +55,7 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
|
||||
envSecret, _ := os.LookupEnv("MANAGEMENT_PASSWORD")
|
||||
envSecret = strings.TrimSpace(envSecret)
|
||||
|
||||
return &Handler{
|
||||
h := &Handler{
|
||||
cfg: cfg,
|
||||
configFilePath: configFilePath,
|
||||
failedAttempts: make(map[string]*attemptInfo),
|
||||
@@ -57,6 +65,43 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
|
||||
allowRemoteOverride: envSecret != "",
|
||||
envSecret: envSecret,
|
||||
}
|
||||
h.startAttemptCleanup()
|
||||
return h
|
||||
}
|
||||
|
||||
// startAttemptCleanup launches a background goroutine that periodically
|
||||
// removes stale IP entries from failedAttempts to prevent memory leaks.
|
||||
func (h *Handler) startAttemptCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(attemptCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
h.purgeStaleAttempts()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// purgeStaleAttempts removes IP entries that have been idle beyond attemptMaxIdleTime
|
||||
// and whose ban (if any) has expired.
|
||||
func (h *Handler) purgeStaleAttempts() {
|
||||
now := time.Now()
|
||||
h.attemptsMu.Lock()
|
||||
defer h.attemptsMu.Unlock()
|
||||
for ip, ai := range h.failedAttempts {
|
||||
// Skip if still banned
|
||||
if !ai.blockedUntil.IsZero() && now.Before(ai.blockedUntil) {
|
||||
continue
|
||||
}
|
||||
// Remove if idle too long
|
||||
if now.Sub(ai.lastActivity) > attemptMaxIdleTime {
|
||||
delete(h.failedAttempts, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NewHandler creates a new management handler instance.
|
||||
func NewHandlerWithoutConfigFilePath(cfg *config.Config, manager *coreauth.Manager) *Handler {
|
||||
return NewHandler(cfg, "", manager)
|
||||
}
|
||||
|
||||
// SetConfig updates the in-memory config reference when the server hot-reloads.
|
||||
@@ -84,6 +129,11 @@ func (h *Handler) SetLogDirectory(dir string) {
|
||||
h.logDir = dir
|
||||
}
|
||||
|
||||
// SetPostAuthHook registers a hook to be called after auth record creation but before persistence.
|
||||
func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
|
||||
h.postAuthHook = hook
|
||||
}
|
||||
|
||||
// Middleware enforces access control for management endpoints.
|
||||
// All requests (local and remote) require a valid management key.
|
||||
// Additionally, remote access requires allow-remote-management=true.
|
||||
@@ -144,6 +194,7 @@ func (h *Handler) Middleware() gin.HandlerFunc {
|
||||
h.failedAttempts[clientIP] = aip
|
||||
}
|
||||
aip.count++
|
||||
aip.lastActivity = time.Now()
|
||||
if aip.count >= maxFailures {
|
||||
aip.blockedUntil = time.Now().Add(banDuration)
|
||||
aip.count = 0
|
||||
@@ -240,16 +291,6 @@ func (h *Handler) updateBoolField(c *gin.Context, set func(bool)) {
|
||||
Value *bool `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
||||
var m map[string]any
|
||||
if err2 := c.ShouldBindJSON(&m); err2 == nil {
|
||||
for _, v := range m {
|
||||
if b, ok := v.(bool); ok {
|
||||
set(b)
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -209,6 +209,94 @@ func (h *Handler) GetRequestErrorLogs(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"files": files})
|
||||
}
|
||||
|
||||
// GetRequestLogByID finds and downloads a request log file by its request ID.
|
||||
// The ID is matched against the suffix of log file names (format: *-{requestID}.log).
|
||||
func (h *Handler) GetRequestLogByID(c *gin.Context) {
|
||||
if h == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
|
||||
return
|
||||
}
|
||||
if h.cfg == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
dir := h.logDirectory()
|
||||
if strings.TrimSpace(dir) == "" {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
requestID := strings.TrimSpace(c.Param("id"))
|
||||
if requestID == "" {
|
||||
requestID = strings.TrimSpace(c.Query("id"))
|
||||
}
|
||||
if requestID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing request ID"})
|
||||
return
|
||||
}
|
||||
if strings.ContainsAny(requestID, "/\\") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "log directory not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list log directory: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
suffix := "-" + requestID + ".log"
|
||||
var matchedFile string
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if strings.HasSuffix(name, suffix) {
|
||||
matchedFile = name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if matchedFile == "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found for the given request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
dirAbs, errAbs := filepath.Abs(dir)
|
||||
if errAbs != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to resolve log directory: %v", errAbs)})
|
||||
return
|
||||
}
|
||||
fullPath := filepath.Clean(filepath.Join(dirAbs, matchedFile))
|
||||
prefix := dirAbs + string(os.PathSeparator)
|
||||
if !strings.HasPrefix(fullPath, prefix) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file path"})
|
||||
return
|
||||
}
|
||||
|
||||
info, errStat := os.Stat(fullPath)
|
||||
if errStat != nil {
|
||||
if os.IsNotExist(errStat) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file: %v", errStat)})
|
||||
return
|
||||
}
|
||||
if info.IsDir() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file"})
|
||||
return
|
||||
}
|
||||
|
||||
c.FileAttachment(fullPath, matchedFile)
|
||||
}
|
||||
|
||||
// DownloadRequestErrorLog downloads a specific error request log file by name.
|
||||
func (h *Handler) DownloadRequestErrorLog(c *gin.Context) {
|
||||
if h == nil {
|
||||
@@ -272,16 +360,7 @@ func (h *Handler) logDirectory() string {
|
||||
if h.logDir != "" {
|
||||
return h.logDir
|
||||
}
|
||||
if base := util.WritablePath(); base != "" {
|
||||
return filepath.Join(base, "logs")
|
||||
}
|
||||
if h.configFilePath != "" {
|
||||
dir := filepath.Dir(h.configFilePath)
|
||||
if dir != "" && dir != "." {
|
||||
return filepath.Join(dir, "logs")
|
||||
}
|
||||
}
|
||||
return "logs"
|
||||
return logging.ResolveLogDirectory(h.cfg)
|
||||
}
|
||||
|
||||
func (h *Handler) collectLogFiles(dir string) ([]string, error) {
|
||||
|
||||
33
internal/api/handlers/management/model_definitions.go
Normal file
33
internal/api/handlers/management/model_definitions.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
)
|
||||
|
||||
// GetStaticModelDefinitions returns static model metadata for a given channel.
|
||||
// Channel is provided via path param (:channel) or query param (?channel=...).
|
||||
func (h *Handler) GetStaticModelDefinitions(c *gin.Context) {
|
||||
channel := strings.TrimSpace(c.Param("channel"))
|
||||
if channel == "" {
|
||||
channel = strings.TrimSpace(c.Query("channel"))
|
||||
}
|
||||
if channel == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "channel is required"})
|
||||
return
|
||||
}
|
||||
|
||||
models := registry.GetStaticModelDefinitionsByChannel(channel)
|
||||
if models == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "unknown channel", "channel": channel})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"channel": strings.ToLower(strings.TrimSpace(channel)),
|
||||
"models": models,
|
||||
})
|
||||
}
|
||||
100
internal/api/handlers/management/oauth_callback.go
Normal file
100
internal/api/handlers/management/oauth_callback.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type oauthCallbackRequest struct {
|
||||
Provider string `json:"provider"`
|
||||
RedirectURL string `json:"redirect_url"`
|
||||
Code string `json:"code"`
|
||||
State string `json:"state"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func (h *Handler) PostOAuthCallback(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "handler not initialized"})
|
||||
return
|
||||
}
|
||||
|
||||
var req oauthCallbackRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid body"})
|
||||
return
|
||||
}
|
||||
|
||||
canonicalProvider, err := NormalizeOAuthProvider(req.Provider)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "unsupported provider"})
|
||||
return
|
||||
}
|
||||
|
||||
state := strings.TrimSpace(req.State)
|
||||
code := strings.TrimSpace(req.Code)
|
||||
errMsg := strings.TrimSpace(req.Error)
|
||||
|
||||
if rawRedirect := strings.TrimSpace(req.RedirectURL); rawRedirect != "" {
|
||||
u, errParse := url.Parse(rawRedirect)
|
||||
if errParse != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid redirect_url"})
|
||||
return
|
||||
}
|
||||
q := u.Query()
|
||||
if state == "" {
|
||||
state = strings.TrimSpace(q.Get("state"))
|
||||
}
|
||||
if code == "" {
|
||||
code = strings.TrimSpace(q.Get("code"))
|
||||
}
|
||||
if errMsg == "" {
|
||||
errMsg = strings.TrimSpace(q.Get("error"))
|
||||
if errMsg == "" {
|
||||
errMsg = strings.TrimSpace(q.Get("error_description"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if state == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "state is required"})
|
||||
return
|
||||
}
|
||||
if err := ValidateOAuthState(state); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"})
|
||||
return
|
||||
}
|
||||
if code == "" && errMsg == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "code or error is required"})
|
||||
return
|
||||
}
|
||||
|
||||
sessionProvider, sessionStatus, ok := GetOAuthSession(state)
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, gin.H{"status": "error", "error": "unknown or expired state"})
|
||||
return
|
||||
}
|
||||
if sessionStatus != "" {
|
||||
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"})
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(sessionProvider, canonicalProvider) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "provider does not match state"})
|
||||
return
|
||||
}
|
||||
|
||||
if _, errWrite := WriteOAuthCallbackFileForPendingSession(h.cfg.AuthDir, canonicalProvider, state, code, errMsg); errWrite != nil {
|
||||
if errors.Is(errWrite, errOAuthSessionNotPending) {
|
||||
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to persist oauth callback"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
}
|
||||
292
internal/api/handlers/management/oauth_sessions.go
Normal file
292
internal/api/handlers/management/oauth_sessions.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
oauthSessionTTL = 10 * time.Minute
|
||||
maxOAuthStateLength = 128
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidOAuthState = errors.New("invalid oauth state")
|
||||
errUnsupportedOAuthFlow = errors.New("unsupported oauth provider")
|
||||
errOAuthSessionNotPending = errors.New("oauth session is not pending")
|
||||
)
|
||||
|
||||
type oauthSession struct {
|
||||
Provider string
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type oauthSessionStore struct {
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
sessions map[string]oauthSession
|
||||
}
|
||||
|
||||
func newOAuthSessionStore(ttl time.Duration) *oauthSessionStore {
|
||||
if ttl <= 0 {
|
||||
ttl = oauthSessionTTL
|
||||
}
|
||||
return &oauthSessionStore{
|
||||
ttl: ttl,
|
||||
sessions: make(map[string]oauthSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) purgeExpiredLocked(now time.Time) {
|
||||
for state, session := range s.sessions {
|
||||
if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) {
|
||||
delete(s.sessions, state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) Register(state, provider string) {
|
||||
state = strings.TrimSpace(state)
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
if state == "" || provider == "" {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
s.sessions[state] = oauthSession{
|
||||
Provider: provider,
|
||||
Status: "",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(s.ttl),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) SetError(state, message string) {
|
||||
state = strings.TrimSpace(state)
|
||||
message = strings.TrimSpace(message)
|
||||
if state == "" {
|
||||
return
|
||||
}
|
||||
if message == "" {
|
||||
message = "Authentication failed"
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
session, ok := s.sessions[state]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
session.Status = message
|
||||
session.ExpiresAt = now.Add(s.ttl)
|
||||
s.sessions[state] = session
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) Complete(state string) {
|
||||
state = strings.TrimSpace(state)
|
||||
if state == "" {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
delete(s.sessions, state)
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) CompleteProvider(provider string) int {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
if provider == "" {
|
||||
return 0
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
removed := 0
|
||||
for state, session := range s.sessions {
|
||||
if strings.EqualFold(session.Provider, provider) {
|
||||
delete(s.sessions, state)
|
||||
removed++
|
||||
}
|
||||
}
|
||||
return removed
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) Get(state string) (oauthSession, bool) {
|
||||
state = strings.TrimSpace(state)
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
session, ok := s.sessions[state]
|
||||
return session, ok
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) IsPending(state, provider string) bool {
|
||||
state = strings.TrimSpace(state)
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
session, ok := s.sessions[state]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if session.Status != "" {
|
||||
if !strings.EqualFold(session.Provider, "kiro") {
|
||||
return false
|
||||
}
|
||||
if !strings.HasPrefix(session.Status, "device_code|") && !strings.HasPrefix(session.Status, "auth_url|") {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if provider == "" {
|
||||
return true
|
||||
}
|
||||
return strings.EqualFold(session.Provider, provider)
|
||||
}
|
||||
|
||||
var oauthSessions = newOAuthSessionStore(oauthSessionTTL)
|
||||
|
||||
func RegisterOAuthSession(state, provider string) { oauthSessions.Register(state, provider) }
|
||||
|
||||
func SetOAuthSessionError(state, message string) { oauthSessions.SetError(state, message) }
|
||||
|
||||
func CompleteOAuthSession(state string) { oauthSessions.Complete(state) }
|
||||
|
||||
func CompleteOAuthSessionsByProvider(provider string) int {
|
||||
return oauthSessions.CompleteProvider(provider)
|
||||
}
|
||||
|
||||
func GetOAuthSession(state string) (provider string, status string, ok bool) {
|
||||
session, ok := oauthSessions.Get(state)
|
||||
if !ok {
|
||||
return "", "", false
|
||||
}
|
||||
return session.Provider, session.Status, true
|
||||
}
|
||||
|
||||
func IsOAuthSessionPending(state, provider string) bool {
|
||||
return oauthSessions.IsPending(state, provider)
|
||||
}
|
||||
|
||||
func ValidateOAuthState(state string) error {
|
||||
trimmed := strings.TrimSpace(state)
|
||||
if trimmed == "" {
|
||||
return fmt.Errorf("%w: empty", errInvalidOAuthState)
|
||||
}
|
||||
if len(trimmed) > maxOAuthStateLength {
|
||||
return fmt.Errorf("%w: too long", errInvalidOAuthState)
|
||||
}
|
||||
if strings.Contains(trimmed, "/") || strings.Contains(trimmed, "\\") {
|
||||
return fmt.Errorf("%w: contains path separator", errInvalidOAuthState)
|
||||
}
|
||||
if strings.Contains(trimmed, "..") {
|
||||
return fmt.Errorf("%w: contains '..'", errInvalidOAuthState)
|
||||
}
|
||||
for _, r := range trimmed {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
case r >= 'A' && r <= 'Z':
|
||||
case r >= '0' && r <= '9':
|
||||
case r == '-' || r == '_' || r == '.':
|
||||
default:
|
||||
return fmt.Errorf("%w: invalid character", errInvalidOAuthState)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NormalizeOAuthProvider(provider string) (string, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||
case "anthropic", "claude":
|
||||
return "anthropic", nil
|
||||
case "codex", "openai":
|
||||
return "codex", nil
|
||||
case "gemini", "google":
|
||||
return "gemini", nil
|
||||
case "iflow", "i-flow":
|
||||
return "iflow", nil
|
||||
case "antigravity", "anti-gravity":
|
||||
return "antigravity", nil
|
||||
case "qwen":
|
||||
return "qwen", nil
|
||||
case "kiro":
|
||||
return "kiro", nil
|
||||
case "github":
|
||||
return "github", nil
|
||||
default:
|
||||
return "", errUnsupportedOAuthFlow
|
||||
}
|
||||
}
|
||||
|
||||
type oauthCallbackFilePayload struct {
|
||||
Code string `json:"code"`
|
||||
State string `json:"state"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func WriteOAuthCallbackFile(authDir, provider, state, code, errorMessage string) (string, error) {
|
||||
if strings.TrimSpace(authDir) == "" {
|
||||
return "", fmt.Errorf("auth dir is empty")
|
||||
}
|
||||
canonicalProvider, err := NormalizeOAuthProvider(provider)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := ValidateOAuthState(state); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf(".oauth-%s-%s.oauth", canonicalProvider, state)
|
||||
filePath := filepath.Join(authDir, fileName)
|
||||
payload := oauthCallbackFilePayload{
|
||||
Code: strings.TrimSpace(code),
|
||||
State: strings.TrimSpace(state),
|
||||
Error: strings.TrimSpace(errorMessage),
|
||||
}
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal oauth callback payload: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(filePath, data, 0o600); err != nil {
|
||||
return "", fmt.Errorf("write oauth callback file: %w", err)
|
||||
}
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
func WriteOAuthCallbackFileForPendingSession(authDir, provider, state, code, errorMessage string) (string, error) {
|
||||
canonicalProvider, err := NormalizeOAuthProvider(provider)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !IsOAuthSessionPending(state, canonicalProvider) {
|
||||
return "", errOAuthSessionNotPending
|
||||
}
|
||||
return WriteOAuthCallbackFile(authDir, canonicalProvider, state, code, errorMessage)
|
||||
}
|
||||
@@ -1,12 +1,25 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
)
|
||||
|
||||
type usageExportPayload struct {
|
||||
Version int `json:"version"`
|
||||
ExportedAt time.Time `json:"exported_at"`
|
||||
Usage usage.StatisticsSnapshot `json:"usage"`
|
||||
}
|
||||
|
||||
type usageImportPayload struct {
|
||||
Version int `json:"version"`
|
||||
Usage usage.StatisticsSnapshot `json:"usage"`
|
||||
}
|
||||
|
||||
// GetUsageStatistics returns the in-memory request statistics snapshot.
|
||||
func (h *Handler) GetUsageStatistics(c *gin.Context) {
|
||||
var snapshot usage.StatisticsSnapshot
|
||||
@@ -18,3 +31,49 @@ func (h *Handler) GetUsageStatistics(c *gin.Context) {
|
||||
"failed_requests": snapshot.FailureCount,
|
||||
})
|
||||
}
|
||||
|
||||
// ExportUsageStatistics returns a complete usage snapshot for backup/migration.
|
||||
func (h *Handler) ExportUsageStatistics(c *gin.Context) {
|
||||
var snapshot usage.StatisticsSnapshot
|
||||
if h != nil && h.usageStats != nil {
|
||||
snapshot = h.usageStats.Snapshot()
|
||||
}
|
||||
c.JSON(http.StatusOK, usageExportPayload{
|
||||
Version: 1,
|
||||
ExportedAt: time.Now().UTC(),
|
||||
Usage: snapshot,
|
||||
})
|
||||
}
|
||||
|
||||
// ImportUsageStatistics merges a previously exported usage snapshot into memory.
|
||||
func (h *Handler) ImportUsageStatistics(c *gin.Context) {
|
||||
if h == nil || h.usageStats == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "usage statistics unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
data, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
|
||||
return
|
||||
}
|
||||
|
||||
var payload usageImportPayload
|
||||
if err := json.Unmarshal(data, &payload); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json"})
|
||||
return
|
||||
}
|
||||
if payload.Version != 0 && payload.Version != 1 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported version"})
|
||||
return
|
||||
}
|
||||
|
||||
result := h.usageStats.MergeSnapshot(payload.Usage)
|
||||
snapshot := h.usageStats.Snapshot()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"added": result.Added,
|
||||
"skipped": result.Skipped,
|
||||
"total_requests": snapshot.TotalRequests,
|
||||
"failed_requests": snapshot.FailureCount,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -8,16 +8,19 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
const maxErrorOnlyCapturedRequestBodyBytes int64 = 1 << 20 // 1 MiB
|
||||
|
||||
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
|
||||
// It captures detailed information about the request and response, including headers and body,
|
||||
// and uses the provided RequestLogger to record this data. When logging is disabled in the
|
||||
// logger, it still captures data so that upstream errors can be persisted.
|
||||
// and uses the provided RequestLogger to record this data. When full request logging is disabled,
|
||||
// body capture is limited to small known-size payloads to avoid large per-request memory spikes.
|
||||
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if logger == nil {
|
||||
@@ -25,7 +28,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
if c.Request.Method == http.MethodGet {
|
||||
if shouldSkipMethodForRequestLogging(c.Request) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -36,8 +39,10 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
loggerEnabled := logger.IsEnabled()
|
||||
|
||||
// Capture request information
|
||||
requestInfo, err := captureRequestInfo(c)
|
||||
requestInfo, err := captureRequestInfo(c, shouldCaptureRequestBody(loggerEnabled, c.Request))
|
||||
if err != nil {
|
||||
// Log error but continue processing
|
||||
// In a real implementation, you might want to use a proper logger here
|
||||
@@ -47,7 +52,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
|
||||
// Create response writer wrapper
|
||||
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
|
||||
if !logger.IsEnabled() {
|
||||
if !loggerEnabled {
|
||||
wrapper.logOnErrorOnly = true
|
||||
}
|
||||
c.Writer = wrapper
|
||||
@@ -63,10 +68,47 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func shouldSkipMethodForRequestLogging(req *http.Request) bool {
|
||||
if req == nil {
|
||||
return true
|
||||
}
|
||||
if req.Method != http.MethodGet {
|
||||
return false
|
||||
}
|
||||
return !isResponsesWebsocketUpgrade(req)
|
||||
}
|
||||
|
||||
func isResponsesWebsocketUpgrade(req *http.Request) bool {
|
||||
if req == nil || req.URL == nil {
|
||||
return false
|
||||
}
|
||||
if req.URL.Path != "/v1/responses" {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(req.Header.Get("Upgrade")), "websocket")
|
||||
}
|
||||
|
||||
func shouldCaptureRequestBody(loggerEnabled bool, req *http.Request) bool {
|
||||
if loggerEnabled {
|
||||
return true
|
||||
}
|
||||
if req == nil || req.Body == nil {
|
||||
return false
|
||||
}
|
||||
contentType := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Type")))
|
||||
if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||
return false
|
||||
}
|
||||
if req.ContentLength <= 0 {
|
||||
return false
|
||||
}
|
||||
return req.ContentLength <= maxErrorOnlyCapturedRequestBodyBytes
|
||||
}
|
||||
|
||||
// captureRequestInfo extracts relevant information from the incoming HTTP request.
|
||||
// It captures the URL, method, headers, and body. The request body is read and then
|
||||
// restored so that it can be processed by subsequent handlers.
|
||||
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
||||
func captureRequestInfo(c *gin.Context, captureBody bool) (*RequestInfo, error) {
|
||||
// Capture URL with sensitive query parameters masked
|
||||
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
||||
url := c.Request.URL.Path
|
||||
@@ -85,7 +127,7 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
||||
|
||||
// Capture request body
|
||||
var body []byte
|
||||
if c.Request.Body != nil {
|
||||
if captureBody && c.Request.Body != nil {
|
||||
// Read the body
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
@@ -98,10 +140,12 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
||||
}
|
||||
|
||||
return &RequestInfo{
|
||||
URL: url,
|
||||
Method: method,
|
||||
Headers: headers,
|
||||
Body: body,
|
||||
URL: url,
|
||||
Method: method,
|
||||
Headers: headers,
|
||||
Body: body,
|
||||
RequestID: logging.GetGinRequestID(c),
|
||||
Timestamp: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -112,5 +156,10 @@ func shouldLogRequest(path string) bool {
|
||||
if strings.HasPrefix(path, "/v0/management") || strings.HasPrefix(path, "/management") {
|
||||
return false
|
||||
}
|
||||
|
||||
if strings.HasPrefix(path, "/api") {
|
||||
return strings.HasPrefix(path, "/api/provider")
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
138
internal/api/middleware/request_logging_test.go
Normal file
138
internal/api/middleware/request_logging_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestShouldSkipMethodForRequestLogging(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *http.Request
|
||||
skip bool
|
||||
}{
|
||||
{
|
||||
name: "nil request",
|
||||
req: nil,
|
||||
skip: true,
|
||||
},
|
||||
{
|
||||
name: "post request should not skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodPost,
|
||||
URL: &url.URL{Path: "/v1/responses"},
|
||||
},
|
||||
skip: false,
|
||||
},
|
||||
{
|
||||
name: "plain get should skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{Path: "/v1/models"},
|
||||
Header: http.Header{},
|
||||
},
|
||||
skip: true,
|
||||
},
|
||||
{
|
||||
name: "responses websocket upgrade should not skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{Path: "/v1/responses"},
|
||||
Header: http.Header{"Upgrade": []string{"websocket"}},
|
||||
},
|
||||
skip: false,
|
||||
},
|
||||
{
|
||||
name: "responses get without upgrade should skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{Path: "/v1/responses"},
|
||||
Header: http.Header{},
|
||||
},
|
||||
skip: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i := range tests {
|
||||
got := shouldSkipMethodForRequestLogging(tests[i].req)
|
||||
if got != tests[i].skip {
|
||||
t.Fatalf("%s: got skip=%t, want %t", tests[i].name, got, tests[i].skip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldCaptureRequestBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
loggerEnabled bool
|
||||
req *http.Request
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "logger enabled always captures",
|
||||
loggerEnabled: true,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("{}")),
|
||||
ContentLength: -1,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "nil request",
|
||||
loggerEnabled: false,
|
||||
req: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "small known size json in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("{}")),
|
||||
ContentLength: 2,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "large known size skipped in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("x")),
|
||||
ContentLength: maxErrorOnlyCapturedRequestBodyBytes + 1,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "unknown size skipped in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("x")),
|
||||
ContentLength: -1,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "multipart skipped in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("x")),
|
||||
ContentLength: 1,
|
||||
Header: http.Header{"Content-Type": []string{"multipart/form-data; boundary=abc"}},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i := range tests {
|
||||
got := shouldCaptureRequestBody(tests[i].loggerEnabled, tests[i].req)
|
||||
if got != tests[i].want {
|
||||
t.Fatalf("%s: got %t, want %t", tests[i].name, got, tests[i].want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -7,34 +7,40 @@ import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
)
|
||||
|
||||
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
|
||||
|
||||
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
||||
type RequestInfo struct {
|
||||
URL string // URL is the request URL.
|
||||
Method string // Method is the HTTP method (e.g., GET, POST).
|
||||
Headers map[string][]string // Headers contains the request headers.
|
||||
Body []byte // Body is the raw request body.
|
||||
URL string // URL is the request URL.
|
||||
Method string // Method is the HTTP method (e.g., GET, POST).
|
||||
Headers map[string][]string // Headers contains the request headers.
|
||||
Body []byte // Body is the raw request body.
|
||||
RequestID string // RequestID is the unique identifier for the request.
|
||||
Timestamp time.Time // Timestamp is when the request was received.
|
||||
}
|
||||
|
||||
// ResponseWriterWrapper wraps the standard gin.ResponseWriter to intercept and log response data.
|
||||
// It is designed to handle both standard and streaming responses, ensuring that logging operations do not block the client response.
|
||||
type ResponseWriterWrapper struct {
|
||||
gin.ResponseWriter
|
||||
body *bytes.Buffer // body is a buffer to store the response body for non-streaming responses.
|
||||
isStreaming bool // isStreaming indicates whether the response is a streaming type (e.g., text/event-stream).
|
||||
streamWriter logging.StreamingLogWriter // streamWriter is a writer for handling streaming log entries.
|
||||
chunkChannel chan []byte // chunkChannel is a channel for asynchronously passing response chunks to the logger.
|
||||
streamDone chan struct{} // streamDone signals when the streaming goroutine completes.
|
||||
logger logging.RequestLogger // logger is the instance of the request logger service.
|
||||
requestInfo *RequestInfo // requestInfo holds the details of the original request.
|
||||
statusCode int // statusCode stores the HTTP status code of the response.
|
||||
headers map[string][]string // headers stores the response headers.
|
||||
logOnErrorOnly bool // logOnErrorOnly enables logging only when an error response is detected.
|
||||
body *bytes.Buffer // body is a buffer to store the response body for non-streaming responses.
|
||||
isStreaming bool // isStreaming indicates whether the response is a streaming type (e.g., text/event-stream).
|
||||
streamWriter logging.StreamingLogWriter // streamWriter is a writer for handling streaming log entries.
|
||||
chunkChannel chan []byte // chunkChannel is a channel for asynchronously passing response chunks to the logger.
|
||||
streamDone chan struct{} // streamDone signals when the streaming goroutine completes.
|
||||
logger logging.RequestLogger // logger is the instance of the request logger service.
|
||||
requestInfo *RequestInfo // requestInfo holds the details of the original request.
|
||||
statusCode int // statusCode stores the HTTP status code of the response.
|
||||
headers map[string][]string // headers stores the response headers.
|
||||
logOnErrorOnly bool // logOnErrorOnly enables logging only when an error response is detected.
|
||||
firstChunkTimestamp time.Time // firstChunkTimestamp captures TTFB for streaming responses.
|
||||
}
|
||||
|
||||
// NewResponseWriterWrapper creates and initializes a new ResponseWriterWrapper.
|
||||
@@ -71,22 +77,72 @@ func (w *ResponseWriterWrapper) Write(data []byte) (int, error) {
|
||||
n, err := w.ResponseWriter.Write(data)
|
||||
|
||||
// THEN: Handle logging based on response type
|
||||
if w.isStreaming {
|
||||
// For streaming responses: Send to async logging channel (non-blocking)
|
||||
if w.chunkChannel != nil {
|
||||
select {
|
||||
case w.chunkChannel <- append([]byte(nil), data...): // Non-blocking send with copy
|
||||
default: // Channel full, skip logging to avoid blocking
|
||||
}
|
||||
if w.isStreaming && w.chunkChannel != nil {
|
||||
// Capture TTFB on first chunk (synchronous, before async channel send)
|
||||
if w.firstChunkTimestamp.IsZero() {
|
||||
w.firstChunkTimestamp = time.Now()
|
||||
}
|
||||
} else {
|
||||
// For non-streaming responses: Buffer complete response
|
||||
// For streaming responses: Send to async logging channel (non-blocking)
|
||||
select {
|
||||
case w.chunkChannel <- append([]byte(nil), data...): // Non-blocking send with copy
|
||||
default: // Channel full, skip logging to avoid blocking
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
if w.shouldBufferResponseBody() {
|
||||
w.body.Write(data)
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) shouldBufferResponseBody() bool {
|
||||
if w.logger != nil && w.logger.IsEnabled() {
|
||||
return true
|
||||
}
|
||||
if !w.logOnErrorOnly {
|
||||
return false
|
||||
}
|
||||
status := w.statusCode
|
||||
if status == 0 {
|
||||
if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok && statusWriter != nil {
|
||||
status = statusWriter.Status()
|
||||
} else {
|
||||
status = http.StatusOK
|
||||
}
|
||||
}
|
||||
return status >= http.StatusBadRequest
|
||||
}
|
||||
|
||||
// WriteString wraps the underlying ResponseWriter's WriteString method to capture response data.
|
||||
// Some handlers (and fmt/io helpers) write via io.StringWriter; without this override, those writes
|
||||
// bypass Write() and would be missing from request logs.
|
||||
func (w *ResponseWriterWrapper) WriteString(data string) (int, error) {
|
||||
w.ensureHeadersCaptured()
|
||||
|
||||
// CRITICAL: Write to client first (zero latency)
|
||||
n, err := w.ResponseWriter.WriteString(data)
|
||||
|
||||
// THEN: Capture for logging
|
||||
if w.isStreaming && w.chunkChannel != nil {
|
||||
// Capture TTFB on first chunk (synchronous, before async channel send)
|
||||
if w.firstChunkTimestamp.IsZero() {
|
||||
w.firstChunkTimestamp = time.Now()
|
||||
}
|
||||
select {
|
||||
case w.chunkChannel <- []byte(data):
|
||||
default:
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
if w.shouldBufferResponseBody() {
|
||||
w.body.WriteString(data)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// WriteHeader wraps the underlying ResponseWriter's WriteHeader method.
|
||||
// It captures the status code, detects if the response is streaming based on the Content-Type header,
|
||||
// and initializes the appropriate logging mechanism (standard or streaming).
|
||||
@@ -107,6 +163,7 @@ func (w *ResponseWriterWrapper) WriteHeader(statusCode int) {
|
||||
w.requestInfo.Method,
|
||||
w.requestInfo.Headers,
|
||||
w.requestInfo.Body,
|
||||
w.requestInfo.RequestID,
|
||||
)
|
||||
if err == nil {
|
||||
w.streamWriter = streamWriter
|
||||
@@ -160,12 +217,16 @@ func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check request body for streaming indicators
|
||||
if w.requestInfo.Body != nil {
|
||||
bodyStr := string(w.requestInfo.Body)
|
||||
if strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`) {
|
||||
return true
|
||||
}
|
||||
// If a concrete Content-Type is already set (e.g., application/json for error responses),
|
||||
// treat it as non-streaming instead of inferring from the request payload.
|
||||
if strings.TrimSpace(contentType) != "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Only fall back to request payload hints when Content-Type is not set yet.
|
||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||
return bytes.Contains(w.requestInfo.Body, []byte(`"stream": true`)) ||
|
||||
bytes.Contains(w.requestInfo.Body, []byte(`"stream":true`))
|
||||
}
|
||||
|
||||
return false
|
||||
@@ -221,7 +282,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if w.isStreaming {
|
||||
if w.isStreaming && w.streamWriter != nil {
|
||||
if w.chunkChannel != nil {
|
||||
close(w.chunkChannel)
|
||||
w.chunkChannel = nil
|
||||
@@ -232,20 +293,26 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
w.streamDone = nil
|
||||
}
|
||||
|
||||
if w.streamWriter != nil {
|
||||
if err := w.streamWriter.Close(); err != nil {
|
||||
w.streamWriter = nil
|
||||
return err
|
||||
}
|
||||
w.streamWriter.SetFirstChunkTimestamp(w.firstChunkTimestamp)
|
||||
|
||||
// Write API Request and Response to the streaming log before closing
|
||||
apiRequest := w.extractAPIRequest(c)
|
||||
if len(apiRequest) > 0 {
|
||||
_ = w.streamWriter.WriteAPIRequest(apiRequest)
|
||||
}
|
||||
apiResponse := w.extractAPIResponse(c)
|
||||
if len(apiResponse) > 0 {
|
||||
_ = w.streamWriter.WriteAPIResponse(apiResponse)
|
||||
}
|
||||
if err := w.streamWriter.Close(); err != nil {
|
||||
w.streamWriter = nil
|
||||
return err
|
||||
}
|
||||
if forceLog {
|
||||
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), slicesAPIResponseError, forceLog)
|
||||
}
|
||||
w.streamWriter = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), slicesAPIResponseError, forceLog)
|
||||
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
||||
@@ -285,18 +352,45 @@ func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||
func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time {
|
||||
ts, isExist := c.Get("API_RESPONSE_TIMESTAMP")
|
||||
if !isExist {
|
||||
return time.Time{}
|
||||
}
|
||||
if t, ok := ts.(time.Time); ok {
|
||||
return t
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
|
||||
if c != nil {
|
||||
if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist {
|
||||
switch value := bodyOverride.(type) {
|
||||
case []byte:
|
||||
if len(value) > 0 {
|
||||
return bytes.Clone(value)
|
||||
}
|
||||
case string:
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return []byte(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||
return w.requestInfo.Body
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||
if w.requestInfo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var requestBody []byte
|
||||
if len(w.requestInfo.Body) > 0 {
|
||||
requestBody = w.requestInfo.Body
|
||||
}
|
||||
|
||||
if loggerWithOptions, ok := w.logger.(interface {
|
||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool) error
|
||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
|
||||
}); ok {
|
||||
return loggerWithOptions.LogRequestWithOptions(
|
||||
w.requestInfo.URL,
|
||||
@@ -310,6 +404,9 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
|
||||
apiResponseBody,
|
||||
apiResponseErrors,
|
||||
forceLog,
|
||||
w.requestInfo.RequestID,
|
||||
w.requestInfo.Timestamp,
|
||||
apiResponseTimestamp,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -324,28 +421,8 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
|
||||
apiRequestBody,
|
||||
apiResponseBody,
|
||||
apiResponseErrors,
|
||||
w.requestInfo.RequestID,
|
||||
w.requestInfo.Timestamp,
|
||||
apiResponseTimestamp,
|
||||
)
|
||||
}
|
||||
|
||||
// Status returns the HTTP response status code captured by the wrapper.
|
||||
// It defaults to 200 if WriteHeader has not been called.
|
||||
func (w *ResponseWriterWrapper) Status() int {
|
||||
if w.statusCode == 0 {
|
||||
return 200 // Default status code
|
||||
}
|
||||
return w.statusCode
|
||||
}
|
||||
|
||||
// Size returns the size of the response body in bytes for non-streaming responses.
|
||||
// For streaming responses, it returns -1, as the total size is unknown.
|
||||
func (w *ResponseWriterWrapper) Size() int {
|
||||
if w.isStreaming {
|
||||
return -1 // Unknown size for streaming responses
|
||||
}
|
||||
return w.body.Len()
|
||||
}
|
||||
|
||||
// Written returns true if the response header has been written (i.e., a status code has been set).
|
||||
func (w *ResponseWriterWrapper) Written() bool {
|
||||
return w.statusCode != 0
|
||||
}
|
||||
|
||||
43
internal/api/middleware/response_writer_test.go
Normal file
43
internal/api/middleware/response_writer_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
wrapper := &ResponseWriterWrapper{
|
||||
requestInfo: &RequestInfo{Body: []byte("original-body")},
|
||||
}
|
||||
|
||||
body := wrapper.extractRequestBody(c)
|
||||
if string(body) != "original-body" {
|
||||
t.Fatalf("request body = %q, want %q", string(body), "original-body")
|
||||
}
|
||||
|
||||
c.Set(requestBodyOverrideContextKey, []byte("override-body"))
|
||||
body = wrapper.extractRequestBody(c)
|
||||
if string(body) != "override-body" {
|
||||
t.Fatalf("request body = %q, want %q", string(body), "override-body")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
wrapper := &ResponseWriterWrapper{}
|
||||
c.Set(requestBodyOverrideContextKey, "override-as-string")
|
||||
|
||||
body := wrapper.extractRequestBody(c)
|
||||
if string(body) != "override-as-string" {
|
||||
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
|
||||
}
|
||||
}
|
||||
@@ -27,11 +27,20 @@ type Option func(*AmpModule)
|
||||
type AmpModule struct {
|
||||
secretSource SecretSource
|
||||
proxy *httputil.ReverseProxy
|
||||
proxyMu sync.RWMutex // protects proxy for hot-reload
|
||||
accessManager *sdkaccess.Manager
|
||||
authMiddleware_ gin.HandlerFunc
|
||||
modelMapper *DefaultModelMapper
|
||||
enabled bool
|
||||
registerOnce sync.Once
|
||||
|
||||
// restrictToLocalhost controls localhost-only access for management routes (hot-reloadable)
|
||||
restrictToLocalhost bool
|
||||
restrictMu sync.RWMutex
|
||||
|
||||
// configMu protects lastConfig for partial reload comparison
|
||||
configMu sync.RWMutex
|
||||
lastConfig *config.AmpCode
|
||||
}
|
||||
|
||||
// New creates a new Amp routing module with the given options.
|
||||
@@ -91,6 +100,16 @@ func (m *AmpModule) Name() string {
|
||||
return "amp-routing"
|
||||
}
|
||||
|
||||
// forceModelMappings returns whether model mappings should take precedence over local API keys
|
||||
func (m *AmpModule) forceModelMappings() bool {
|
||||
m.configMu.RLock()
|
||||
defer m.configMu.RUnlock()
|
||||
if m.lastConfig == nil {
|
||||
return false
|
||||
}
|
||||
return m.lastConfig.ForceModelMappings
|
||||
}
|
||||
|
||||
// Register sets up Amp routes if configured.
|
||||
// This implements the RouteModuleV2 interface with Context.
|
||||
// Routes are registered only once via sync.Once for idempotent behavior.
|
||||
@@ -107,9 +126,19 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
||||
// Initialize model mapper from config (for routing unavailable models to alternatives)
|
||||
m.modelMapper = NewModelMapper(settings.ModelMappings)
|
||||
|
||||
// Store initial config for partial reload comparison
|
||||
m.lastConfig = new(settings)
|
||||
|
||||
// Initialize localhost restriction setting (hot-reloadable)
|
||||
m.setRestrictToLocalhost(settings.RestrictManagementToLocalhost)
|
||||
|
||||
// Always register provider aliases - these work without an upstream
|
||||
m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)
|
||||
|
||||
// Register management proxy routes once; middleware will gate access when upstream is unavailable.
|
||||
// Pass auth middleware to require valid API key for all management routes.
|
||||
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler, auth)
|
||||
|
||||
// If no upstream URL, skip proxy routes but provider aliases are still available
|
||||
if upstreamURL == "" {
|
||||
log.Debug("amp upstream proxy disabled (no upstream URL configured)")
|
||||
@@ -118,28 +147,11 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
||||
return
|
||||
}
|
||||
|
||||
// Create secret source with precedence: config > env > file
|
||||
// Cache secrets for 5 minutes to reduce file I/O
|
||||
if m.secretSource == nil {
|
||||
m.secretSource = NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
|
||||
}
|
||||
|
||||
// Create reverse proxy with gzip handling via ModifyResponse
|
||||
proxy, err := createReverseProxy(upstreamURL, m.secretSource)
|
||||
if err != nil {
|
||||
if err := m.enableUpstreamProxy(upstreamURL, &settings); err != nil {
|
||||
regErr = fmt.Errorf("failed to create amp proxy: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
m.proxy = proxy
|
||||
m.enabled = true
|
||||
|
||||
// Register management proxy routes (requires upstream)
|
||||
// Restrict to localhost by default for security (prevents drive-by browser attacks)
|
||||
handler := proxyHandler(proxy)
|
||||
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler, handler, settings.RestrictManagementToLocalhost)
|
||||
|
||||
log.Infof("amp upstream proxy enabled for: %s", upstreamURL)
|
||||
log.Debug("amp provider alias routes registered")
|
||||
})
|
||||
|
||||
@@ -162,44 +174,254 @@ func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// OnConfigUpdated handles configuration updates.
|
||||
// Currently requires restart for URL changes (could be enhanced for dynamic updates).
|
||||
// OnConfigUpdated handles configuration updates with partial reload support.
|
||||
// Only updates components that have actually changed to avoid unnecessary work.
|
||||
// Supports hot-reload for: model-mappings, upstream-api-key, upstream-url, restrict-management-to-localhost.
|
||||
func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||
settings := cfg.AmpCode
|
||||
newSettings := cfg.AmpCode
|
||||
|
||||
// Update model mappings (hot-reload supported)
|
||||
if m.modelMapper != nil {
|
||||
m.modelMapper.UpdateMappings(settings.ModelMappings)
|
||||
if m.enabled {
|
||||
log.Infof("amp config updated: reloading %d model mapping(s)", len(settings.ModelMappings))
|
||||
}
|
||||
} else if m.enabled {
|
||||
log.Warnf("amp model mapper not initialized, skipping model mapping update")
|
||||
// Get previous config for comparison
|
||||
m.configMu.RLock()
|
||||
oldSettings := m.lastConfig
|
||||
m.configMu.RUnlock()
|
||||
|
||||
if oldSettings != nil && oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost {
|
||||
m.setRestrictToLocalhost(newSettings.RestrictManagementToLocalhost)
|
||||
}
|
||||
|
||||
if !m.enabled {
|
||||
return nil
|
||||
newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL)
|
||||
oldUpstreamURL := ""
|
||||
if oldSettings != nil {
|
||||
oldUpstreamURL = strings.TrimSpace(oldSettings.UpstreamURL)
|
||||
}
|
||||
|
||||
upstreamURL := strings.TrimSpace(settings.UpstreamURL)
|
||||
if upstreamURL == "" {
|
||||
log.Warn("amp upstream URL removed from config, restart required to disable")
|
||||
return nil
|
||||
}
|
||||
|
||||
// If API key changed, invalidate the cache
|
||||
if m.secretSource != nil {
|
||||
if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||
ms.InvalidateCache()
|
||||
log.Debug("amp secret cache invalidated due to config update")
|
||||
if !m.enabled && newUpstreamURL != "" {
|
||||
if err := m.enableUpstreamProxy(newUpstreamURL, &newSettings); err != nil {
|
||||
log.Errorf("amp config: failed to enable upstream proxy for %s: %v", newUpstreamURL, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("amp config updated (restart required for URL changes)")
|
||||
// Check model mappings change
|
||||
modelMappingsChanged := m.hasModelMappingsChanged(oldSettings, &newSettings)
|
||||
if modelMappingsChanged {
|
||||
if m.modelMapper != nil {
|
||||
m.modelMapper.UpdateMappings(newSettings.ModelMappings)
|
||||
} else if m.enabled {
|
||||
log.Warnf("amp model mapper not initialized, skipping model mapping update")
|
||||
}
|
||||
}
|
||||
|
||||
if m.enabled {
|
||||
// Check upstream URL change - now supports hot-reload
|
||||
if newUpstreamURL == "" && oldUpstreamURL != "" {
|
||||
m.setProxy(nil)
|
||||
m.enabled = false
|
||||
} else if oldUpstreamURL != "" && newUpstreamURL != oldUpstreamURL && newUpstreamURL != "" {
|
||||
// Recreate proxy with new URL
|
||||
proxy, err := createReverseProxy(newUpstreamURL, m.secretSource)
|
||||
if err != nil {
|
||||
log.Errorf("amp config: failed to create proxy for new upstream URL %s: %v", newUpstreamURL, err)
|
||||
} else {
|
||||
m.setProxy(proxy)
|
||||
}
|
||||
}
|
||||
|
||||
// Check API key change (both default and per-client mappings)
|
||||
apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings)
|
||||
upstreamAPIKeysChanged := m.hasUpstreamAPIKeysChanged(oldSettings, &newSettings)
|
||||
if apiKeyChanged || upstreamAPIKeysChanged {
|
||||
if m.secretSource != nil {
|
||||
if ms, ok := m.secretSource.(*MappedSecretSource); ok {
|
||||
if apiKeyChanged {
|
||||
ms.UpdateDefaultExplicitKey(newSettings.UpstreamAPIKey)
|
||||
ms.InvalidateCache()
|
||||
}
|
||||
if upstreamAPIKeysChanged {
|
||||
ms.UpdateMappings(newSettings.UpstreamAPIKeys)
|
||||
}
|
||||
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||
ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
|
||||
ms.InvalidateCache()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Store current config for next comparison
|
||||
m.configMu.Lock()
|
||||
settingsCopy := newSettings // copy struct
|
||||
m.lastConfig = &settingsCopy
|
||||
m.configMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error {
|
||||
if m.secretSource == nil {
|
||||
// Create MultiSourceSecret as the default source, then wrap with MappedSecretSource
|
||||
defaultSource := NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
|
||||
mappedSource := NewMappedSecretSource(defaultSource)
|
||||
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
|
||||
m.secretSource = mappedSource
|
||||
} else if ms, ok := m.secretSource.(*MappedSecretSource); ok {
|
||||
ms.UpdateDefaultExplicitKey(settings.UpstreamAPIKey)
|
||||
ms.InvalidateCache()
|
||||
ms.UpdateMappings(settings.UpstreamAPIKeys)
|
||||
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||
// Legacy path: wrap existing MultiSourceSecret with MappedSecretSource
|
||||
ms.UpdateExplicitKey(settings.UpstreamAPIKey)
|
||||
ms.InvalidateCache()
|
||||
mappedSource := NewMappedSecretSource(ms)
|
||||
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
|
||||
m.secretSource = mappedSource
|
||||
}
|
||||
|
||||
proxy, err := createReverseProxy(upstreamURL, m.secretSource)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.setProxy(proxy)
|
||||
m.enabled = true
|
||||
|
||||
log.Infof("amp upstream proxy enabled for: %s", upstreamURL)
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasModelMappingsChanged compares old and new model mappings.
|
||||
func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.AmpCode) bool {
|
||||
if old == nil {
|
||||
return len(new.ModelMappings) > 0
|
||||
}
|
||||
|
||||
if len(old.ModelMappings) != len(new.ModelMappings) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Build map for efficient and robust comparison
|
||||
type mappingInfo struct {
|
||||
to string
|
||||
regex bool
|
||||
}
|
||||
oldMap := make(map[string]mappingInfo, len(old.ModelMappings))
|
||||
for _, mapping := range old.ModelMappings {
|
||||
oldMap[strings.TrimSpace(mapping.From)] = mappingInfo{
|
||||
to: strings.TrimSpace(mapping.To),
|
||||
regex: mapping.Regex,
|
||||
}
|
||||
}
|
||||
|
||||
for _, mapping := range new.ModelMappings {
|
||||
from := strings.TrimSpace(mapping.From)
|
||||
to := strings.TrimSpace(mapping.To)
|
||||
if oldVal, exists := oldMap[from]; !exists || oldVal.to != to || oldVal.regex != mapping.Regex {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// hasAPIKeyChanged compares old and new API keys.
|
||||
func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) bool {
|
||||
oldKey := ""
|
||||
if old != nil {
|
||||
oldKey = strings.TrimSpace(old.UpstreamAPIKey)
|
||||
}
|
||||
newKey := strings.TrimSpace(new.UpstreamAPIKey)
|
||||
return oldKey != newKey
|
||||
}
|
||||
|
||||
// hasUpstreamAPIKeysChanged compares old and new per-client upstream API key mappings.
|
||||
func (m *AmpModule) hasUpstreamAPIKeysChanged(old *config.AmpCode, new *config.AmpCode) bool {
|
||||
if old == nil {
|
||||
return len(new.UpstreamAPIKeys) > 0
|
||||
}
|
||||
|
||||
if len(old.UpstreamAPIKeys) != len(new.UpstreamAPIKeys) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Build map for comparison: upstreamKey -> set of clientKeys
|
||||
type entryInfo struct {
|
||||
upstreamKey string
|
||||
clientKeys map[string]struct{}
|
||||
}
|
||||
oldEntries := make([]entryInfo, len(old.UpstreamAPIKeys))
|
||||
for i, entry := range old.UpstreamAPIKeys {
|
||||
clientKeys := make(map[string]struct{}, len(entry.APIKeys))
|
||||
for _, k := range entry.APIKeys {
|
||||
trimmed := strings.TrimSpace(k)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
clientKeys[trimmed] = struct{}{}
|
||||
}
|
||||
oldEntries[i] = entryInfo{
|
||||
upstreamKey: strings.TrimSpace(entry.UpstreamAPIKey),
|
||||
clientKeys: clientKeys,
|
||||
}
|
||||
}
|
||||
|
||||
for i, newEntry := range new.UpstreamAPIKeys {
|
||||
if i >= len(oldEntries) {
|
||||
return true
|
||||
}
|
||||
oldE := oldEntries[i]
|
||||
if strings.TrimSpace(newEntry.UpstreamAPIKey) != oldE.upstreamKey {
|
||||
return true
|
||||
}
|
||||
newKeys := make(map[string]struct{}, len(newEntry.APIKeys))
|
||||
for _, k := range newEntry.APIKeys {
|
||||
trimmed := strings.TrimSpace(k)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
newKeys[trimmed] = struct{}{}
|
||||
}
|
||||
if len(newKeys) != len(oldE.clientKeys) {
|
||||
return true
|
||||
}
|
||||
for k := range newKeys {
|
||||
if _, ok := oldE.clientKeys[k]; !ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetModelMapper returns the model mapper instance (for testing/debugging).
|
||||
func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
|
||||
return m.modelMapper
|
||||
}
|
||||
|
||||
// getProxy returns the current proxy instance (thread-safe for hot-reload).
|
||||
func (m *AmpModule) getProxy() *httputil.ReverseProxy {
|
||||
m.proxyMu.RLock()
|
||||
defer m.proxyMu.RUnlock()
|
||||
return m.proxy
|
||||
}
|
||||
|
||||
// setProxy updates the proxy instance (thread-safe for hot-reload).
|
||||
func (m *AmpModule) setProxy(proxy *httputil.ReverseProxy) {
|
||||
m.proxyMu.Lock()
|
||||
defer m.proxyMu.Unlock()
|
||||
m.proxy = proxy
|
||||
}
|
||||
|
||||
// IsRestrictedToLocalhost returns whether management routes are restricted to localhost.
|
||||
func (m *AmpModule) IsRestrictedToLocalhost() bool {
|
||||
m.restrictMu.RLock()
|
||||
defer m.restrictMu.RUnlock()
|
||||
return m.restrictToLocalhost
|
||||
}
|
||||
|
||||
// setRestrictToLocalhost updates the localhost restriction setting.
|
||||
func (m *AmpModule) setRestrictToLocalhost(restrict bool) {
|
||||
m.restrictMu.Lock()
|
||||
defer m.restrictMu.Unlock()
|
||||
m.restrictToLocalhost = restrict
|
||||
}
|
||||
|
||||
@@ -146,6 +146,9 @@ func TestAmpModule_OnConfigUpdated_CacheInvalidation(t *testing.T) {
|
||||
m := &AmpModule{enabled: true}
|
||||
ms := NewMultiSourceSecretWithPath("", p, time.Minute)
|
||||
m.secretSource = ms
|
||||
m.lastConfig = &config.AmpCode{
|
||||
UpstreamAPIKey: "old-key",
|
||||
}
|
||||
|
||||
// Warm the cache
|
||||
if _, err := ms.Get(context.Background()); err != nil {
|
||||
@@ -157,7 +160,7 @@ func TestAmpModule_OnConfigUpdated_CacheInvalidation(t *testing.T) {
|
||||
}
|
||||
|
||||
// Update config - should invalidate cache
|
||||
if err := m.OnConfigUpdated(&config.Config{AmpCode: config.AmpCode{UpstreamURL: "http://x"}}); err != nil {
|
||||
if err := m.OnConfigUpdated(&config.Config{AmpCode: config.AmpCode{UpstreamURL: "http://x", UpstreamAPIKey: "new-key"}}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -309,3 +312,41 @@ func TestAmpModule_ProviderAliasesAlwaysRegistered(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_hasUpstreamAPIKeysChanged_DetectsRemovedKeyWithDuplicateInput(t *testing.T) {
|
||||
m := &AmpModule{}
|
||||
|
||||
oldCfg := &config.AmpCode{
|
||||
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
|
||||
},
|
||||
}
|
||||
newCfg := &config.AmpCode{
|
||||
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k1"}},
|
||||
},
|
||||
}
|
||||
|
||||
if !m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
|
||||
t.Fatal("expected change to be detected when k2 is removed but new list contains duplicates")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpModule_hasUpstreamAPIKeysChanged_IgnoresEmptyAndWhitespaceKeys(t *testing.T) {
|
||||
m := &AmpModule{}
|
||||
|
||||
oldCfg := &config.AmpCode{
|
||||
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
|
||||
},
|
||||
}
|
||||
newCfg := &config.AmpCode{
|
||||
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||
{UpstreamAPIKey: "u1", APIKeys: []string{" k1 ", "", "k2", " "}},
|
||||
},
|
||||
}
|
||||
|
||||
if m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
|
||||
t.Fatal("expected no change when only whitespace/empty entries differ")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,15 +2,17 @@ package amp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// AmpRouteType represents the type of routing decision made for an Amp request
|
||||
@@ -27,6 +29,9 @@ const (
|
||||
RouteTypeNoProvider AmpRouteType = "NO_PROVIDER"
|
||||
)
|
||||
|
||||
// MappedModelContextKey is the Gin context key for passing mapped model names.
|
||||
const MappedModelContextKey = "mapped_model"
|
||||
|
||||
// logAmpRouting logs the routing decision for an Amp request with structured fields
|
||||
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
|
||||
fields := log.Fields{
|
||||
@@ -48,48 +53,54 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid
|
||||
case RouteTypeLocalProvider:
|
||||
fields["cost"] = "free"
|
||||
fields["source"] = "local_oauth"
|
||||
log.WithFields(fields).Infof("[amp] using local provider for model: %s", requestedModel)
|
||||
log.WithFields(fields).Debugf("amp using local provider for model: %s", requestedModel)
|
||||
|
||||
case RouteTypeModelMapping:
|
||||
fields["cost"] = "free"
|
||||
fields["source"] = "local_oauth"
|
||||
fields["mapping"] = requestedModel + " -> " + resolvedModel
|
||||
log.WithFields(fields).Infof("[amp] model mapped: %s -> %s", requestedModel, resolvedModel)
|
||||
// model mapping already logged in mapper; avoid duplicate here
|
||||
|
||||
case RouteTypeAmpCredits:
|
||||
fields["cost"] = "amp_credits"
|
||||
fields["source"] = "ampcode.com"
|
||||
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
||||
log.WithFields(fields).Warnf("[amp] forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local proxy, add to config: amp-model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
|
||||
log.WithFields(fields).Warnf("forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local provider, add to config: ampcode.model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
|
||||
|
||||
case RouteTypeNoProvider:
|
||||
fields["cost"] = "none"
|
||||
fields["source"] = "error"
|
||||
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
||||
log.WithFields(fields).Warnf("[amp] no provider available for model_id: %s", requestedModel)
|
||||
log.WithFields(fields).Warnf("no provider available for model_id: %s", requestedModel)
|
||||
}
|
||||
}
|
||||
|
||||
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
|
||||
// when the model's provider is not available in CLIProxyAPI
|
||||
type FallbackHandler struct {
|
||||
getProxy func() *httputil.ReverseProxy
|
||||
modelMapper ModelMapper
|
||||
getProxy func() *httputil.ReverseProxy
|
||||
modelMapper ModelMapper
|
||||
forceModelMappings func() bool
|
||||
}
|
||||
|
||||
// NewFallbackHandler creates a new fallback handler wrapper
|
||||
// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
|
||||
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
|
||||
return &FallbackHandler{
|
||||
getProxy: getProxy,
|
||||
getProxy: getProxy,
|
||||
forceModelMappings: func() bool { return false },
|
||||
}
|
||||
}
|
||||
|
||||
// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
|
||||
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper) *FallbackHandler {
|
||||
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler {
|
||||
if forceModelMappings == nil {
|
||||
forceModelMappings = func() bool { return false }
|
||||
}
|
||||
return &FallbackHandler{
|
||||
getProxy: getProxy,
|
||||
modelMapper: mapper,
|
||||
getProxy: getProxy,
|
||||
modelMapper: mapper,
|
||||
forceModelMappings: forceModelMappings,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -123,35 +134,93 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
return
|
||||
}
|
||||
|
||||
// Normalize model (handles Gemini thinking suffixes)
|
||||
normalizedModel, _ := util.NormalizeGeminiThinkingModel(modelName)
|
||||
// Normalize model (handles dynamic thinking suffixes)
|
||||
suffixResult := thinking.ParseSuffix(modelName)
|
||||
normalizedModel := suffixResult.ModelName
|
||||
thinkingSuffix := ""
|
||||
if suffixResult.HasSuffix {
|
||||
thinkingSuffix = "(" + suffixResult.RawSuffix + ")"
|
||||
}
|
||||
|
||||
// Check if we have providers for this model
|
||||
providers := util.GetProviderName(normalizedModel)
|
||||
resolveMappedModel := func() (string, []string) {
|
||||
if fh.modelMapper == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
mappedModel := fh.modelMapper.MapModel(modelName)
|
||||
if mappedModel == "" {
|
||||
mappedModel = fh.modelMapper.MapModel(normalizedModel)
|
||||
}
|
||||
mappedModel = strings.TrimSpace(mappedModel)
|
||||
if mappedModel == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target
|
||||
// already specifies its own thinking suffix.
|
||||
if thinkingSuffix != "" {
|
||||
mappedSuffixResult := thinking.ParseSuffix(mappedModel)
|
||||
if !mappedSuffixResult.HasSuffix {
|
||||
mappedModel += thinkingSuffix
|
||||
}
|
||||
}
|
||||
|
||||
mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName
|
||||
mappedProviders := util.GetProviderName(mappedBaseModel)
|
||||
if len(mappedProviders) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return mappedModel, mappedProviders
|
||||
}
|
||||
|
||||
// Track resolved model for logging (may change if mapping is applied)
|
||||
resolvedModel := normalizedModel
|
||||
usedMapping := false
|
||||
var providers []string
|
||||
|
||||
if len(providers) == 0 {
|
||||
// No providers configured - check if we have a model mapping
|
||||
if fh.modelMapper != nil {
|
||||
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
|
||||
// Mapping found - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInBody(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
resolvedModel = mappedModel
|
||||
usedMapping = true
|
||||
// Check if model mappings should be forced ahead of local API keys
|
||||
forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings()
|
||||
|
||||
// Get providers for the mapped model
|
||||
providers = util.GetProviderName(mappedModel)
|
||||
|
||||
// Continue to handler with remapped model
|
||||
goto handleRequest
|
||||
}
|
||||
if forceMappings {
|
||||
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
|
||||
// This allows users to route Amp requests to their preferred OAuth providers
|
||||
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
|
||||
// Mapping found and provider available - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||
c.Set(MappedModelContextKey, mappedModel)
|
||||
resolvedModel = mappedModel
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
}
|
||||
|
||||
// No mapping found - check if we have a proxy for fallback
|
||||
// If no mapping applied, check for local providers
|
||||
if !usedMapping {
|
||||
providers = util.GetProviderName(normalizedModel)
|
||||
}
|
||||
} else {
|
||||
// DEFAULT MODE: Check local providers first, then mappings as fallback
|
||||
providers = util.GetProviderName(normalizedModel)
|
||||
|
||||
if len(providers) == 0 {
|
||||
// No providers configured - check if we have a model mapping
|
||||
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
|
||||
// Mapping found and provider available - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||
c.Set(MappedModelContextKey, mappedModel)
|
||||
resolvedModel = mappedModel
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no providers available, fallback to ampcode.com
|
||||
if len(providers) == 0 {
|
||||
proxy := fh.getProxy()
|
||||
if proxy != nil {
|
||||
// Log: Forwarding to ampcode.com (uses Amp credits)
|
||||
@@ -169,8 +238,6 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath)
|
||||
}
|
||||
|
||||
handleRequest:
|
||||
|
||||
// Log the routing decision
|
||||
providerName := ""
|
||||
if len(providers) > 0 {
|
||||
@@ -179,59 +246,62 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
|
||||
if usedMapping {
|
||||
// Log: Model was mapped to another model
|
||||
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
|
||||
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
||||
rewriter := NewResponseRewriter(c.Writer, modelName)
|
||||
c.Writer = rewriter
|
||||
// Filter Anthropic-Beta header only for local handling paths
|
||||
filterAntropicBetaHeader(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
handler(c)
|
||||
rewriter.Flush()
|
||||
log.Debugf("amp model mapping: response %s -> %s", resolvedModel, modelName)
|
||||
} else if len(providers) > 0 {
|
||||
// Log: Using local provider (free)
|
||||
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
||||
// Filter Anthropic-Beta header only for local handling paths
|
||||
filterAntropicBetaHeader(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
handler(c)
|
||||
} else {
|
||||
// No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
handler(c)
|
||||
}
|
||||
|
||||
// Providers available or no proxy for fallback, restore body and use normal handler
|
||||
// Filter Anthropic-Beta header to remove features requiring special subscription
|
||||
// This is needed when using local providers (bypassing the Amp proxy)
|
||||
if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" {
|
||||
filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07")
|
||||
if filtered != "" {
|
||||
c.Request.Header.Set("Anthropic-Beta", filtered)
|
||||
} else {
|
||||
c.Request.Header.Del("Anthropic-Beta")
|
||||
}
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
handler(c)
|
||||
}
|
||||
}
|
||||
|
||||
// rewriteModelInBody replaces the model name in a JSON request body
|
||||
func rewriteModelInBody(body []byte, newModel string) []byte {
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
log.Warnf("amp model mapping: failed to parse body for rewrite: %v", err)
|
||||
// filterAntropicBetaHeader filters Anthropic-Beta header to remove features requiring special subscription
|
||||
// This is needed when using local providers (bypassing the Amp proxy)
|
||||
func filterAntropicBetaHeader(c *gin.Context) {
|
||||
if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" {
|
||||
if filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07"); filtered != "" {
|
||||
c.Request.Header.Set("Anthropic-Beta", filtered)
|
||||
} else {
|
||||
c.Request.Header.Del("Anthropic-Beta")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// rewriteModelInRequest replaces the model name in a JSON request body
|
||||
func rewriteModelInRequest(body []byte, newModel string) []byte {
|
||||
if !gjson.GetBytes(body, "model").Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
if _, exists := payload["model"]; exists {
|
||||
payload["model"] = newModel
|
||||
newBody, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
log.Warnf("amp model mapping: failed to marshal rewritten body: %v", err)
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
result, err := sjson.SetBytes(body, "model", newModel)
|
||||
if err != nil {
|
||||
log.Warnf("amp model mapping: failed to rewrite model in request body: %v", err)
|
||||
return body
|
||||
}
|
||||
|
||||
return body
|
||||
return result
|
||||
}
|
||||
|
||||
// extractModelFromRequest attempts to extract the model name from various request formats
|
||||
func extractModelFromRequest(body []byte, c *gin.Context) string {
|
||||
// First try to parse from JSON body (OpenAI, Claude, etc.)
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal(body, &payload); err == nil {
|
||||
// Check common model field names
|
||||
if model, ok := payload["model"].(string); ok {
|
||||
return model
|
||||
}
|
||||
// Check common model field names
|
||||
if result := gjson.GetBytes(body, "model"); result.Exists() && result.Type == gjson.String {
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// For Gemini requests, model is in the URL path
|
||||
|
||||
73
internal/api/modules/amp/fallback_handlers_test.go
Normal file
73
internal/api/modules/amp/fallback_handlers_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
)
|
||||
|
||||
func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-amp-fallback", "codex", []*registry.ModelInfo{
|
||||
{ID: "test/gpt-5.2", OwnedBy: "openai", Type: "codex"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-amp-fallback")
|
||||
|
||||
mapper := NewModelMapper([]config.AmpModelMapping{
|
||||
{From: "gpt-5.2", To: "test/gpt-5.2"},
|
||||
})
|
||||
|
||||
fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, mapper, nil)
|
||||
|
||||
handler := func(c *gin.Context) {
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"model": req.Model,
|
||||
"seen_model": req.Model,
|
||||
})
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/chat/completions", fallback.WrapHandler(handler))
|
||||
|
||||
reqBody := []byte(`{"model":"gpt-5.2(xhigh)"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat/completions", bytes.NewReader(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Model string `json:"model"`
|
||||
SeenModel string `json:"seen_model"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Failed to parse response JSON: %v", err)
|
||||
}
|
||||
|
||||
if resp.Model != "gpt-5.2(xhigh)" {
|
||||
t.Errorf("Expected response model gpt-5.2(xhigh), got %s", resp.Model)
|
||||
}
|
||||
if resp.SeenModel != "test/gpt-5.2(xhigh)" {
|
||||
t.Errorf("Expected handler to see test/gpt-5.2(xhigh), got %s", resp.SeenModel)
|
||||
}
|
||||
}
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
||||
)
|
||||
|
||||
// createGeminiBridgeHandler creates a handler that bridges AMP CLI's non-standard Gemini paths
|
||||
@@ -15,16 +14,31 @@ import (
|
||||
//
|
||||
// This extracts the model+method from the AMP path and sets it as the :action parameter
|
||||
// so the standard Gemini handler can process it.
|
||||
func createGeminiBridgeHandler(geminiHandler *gemini.GeminiAPIHandler) gin.HandlerFunc {
|
||||
//
|
||||
// The handler parameter should be a Gemini-compatible handler that expects the :action param.
|
||||
func createGeminiBridgeHandler(handler gin.HandlerFunc) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Get the full path from the catch-all parameter
|
||||
path := c.Param("path")
|
||||
|
||||
// Extract model:method from AMP CLI path format
|
||||
// Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
|
||||
if idx := strings.Index(path, "/models/"); idx >= 0 {
|
||||
// Extract everything after "/models/"
|
||||
actionPart := path[idx+8:] // Skip "/models/"
|
||||
const modelsPrefix = "/models/"
|
||||
if idx := strings.Index(path, modelsPrefix); idx >= 0 {
|
||||
// Extract everything after modelsPrefix
|
||||
actionPart := path[idx+len(modelsPrefix):]
|
||||
|
||||
// Check if model was mapped by FallbackHandler
|
||||
if mappedModel, exists := c.Get(MappedModelContextKey); exists {
|
||||
if strModel, ok := mappedModel.(string); ok && strModel != "" {
|
||||
// Replace the model part in the action
|
||||
// actionPart is like "model-name:method"
|
||||
if colonIdx := strings.Index(actionPart, ":"); colonIdx > 0 {
|
||||
method := actionPart[colonIdx:] // ":method"
|
||||
actionPart = strModel + method
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set this as the :action parameter that the Gemini handler expects
|
||||
c.Params = append(c.Params, gin.Param{
|
||||
@@ -32,8 +46,8 @@ func createGeminiBridgeHandler(geminiHandler *gemini.GeminiAPIHandler) gin.Handl
|
||||
Value: actionPart,
|
||||
})
|
||||
|
||||
// Call the standard Gemini handler
|
||||
geminiHandler.GeminiHandler(c)
|
||||
// Call the handler
|
||||
handler(c)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
93
internal/api/modules/amp/gemini_bridge_test.go
Normal file
93
internal/api/modules/amp/gemini_bridge_test.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestCreateGeminiBridgeHandler_ActionParameterExtraction(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
mappedModel string // empty string means no mapping
|
||||
expectedAction string
|
||||
}{
|
||||
{
|
||||
name: "no_mapping_uses_url_model",
|
||||
path: "/publishers/google/models/gemini-pro:generateContent",
|
||||
mappedModel: "",
|
||||
expectedAction: "gemini-pro:generateContent",
|
||||
},
|
||||
{
|
||||
name: "mapped_model_replaces_url_model",
|
||||
path: "/publishers/google/models/gemini-exp:generateContent",
|
||||
mappedModel: "gemini-2.0-flash",
|
||||
expectedAction: "gemini-2.0-flash:generateContent",
|
||||
},
|
||||
{
|
||||
name: "mapping_preserves_method",
|
||||
path: "/publishers/google/models/gemini-2.5-preview:streamGenerateContent",
|
||||
mappedModel: "gemini-flash",
|
||||
expectedAction: "gemini-flash:streamGenerateContent",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var capturedAction string
|
||||
|
||||
mockGeminiHandler := func(c *gin.Context) {
|
||||
capturedAction = c.Param("action")
|
||||
c.JSON(http.StatusOK, gin.H{"captured": capturedAction})
|
||||
}
|
||||
|
||||
// Use the actual createGeminiBridgeHandler function
|
||||
bridgeHandler := createGeminiBridgeHandler(mockGeminiHandler)
|
||||
|
||||
r := gin.New()
|
||||
if tt.mappedModel != "" {
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set(MappedModelContextKey, tt.mappedModel)
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
r.POST("/api/provider/google/v1beta1/*path", bridgeHandler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1"+tt.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
if capturedAction != tt.expectedAction {
|
||||
t.Errorf("Expected action '%s', got '%s'", tt.expectedAction, capturedAction)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateGeminiBridgeHandler_InvalidPath(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mockHandler := func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
}
|
||||
bridgeHandler := createGeminiBridgeHandler(mockHandler)
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/api/provider/google/v1beta1/*path", bridgeHandler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1/invalid/path", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status 400 for invalid path, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
@@ -3,10 +3,12 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -26,13 +28,15 @@ type ModelMapper interface {
|
||||
// DefaultModelMapper implements ModelMapper with thread-safe mapping storage.
|
||||
type DefaultModelMapper struct {
|
||||
mu sync.RWMutex
|
||||
mappings map[string]string // from -> to (normalized lowercase keys)
|
||||
mappings map[string]string // exact: from -> to (normalized lowercase keys)
|
||||
regexps []regexMapping // regex rules evaluated in order
|
||||
}
|
||||
|
||||
// NewModelMapper creates a new model mapper with the given initial mappings.
|
||||
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
|
||||
m := &DefaultModelMapper{
|
||||
mappings: make(map[string]string),
|
||||
regexps: nil,
|
||||
}
|
||||
m.UpdateMappings(mappings)
|
||||
return m
|
||||
@@ -41,6 +45,11 @@ func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
|
||||
// MapModel checks if a mapping exists for the requested model and if the
|
||||
// target model has available local providers. Returns the mapped model name
|
||||
// or empty string if no valid mapping exists.
|
||||
//
|
||||
// If the requested model contains a thinking suffix (e.g., "g25p(8192)"),
|
||||
// the suffix is preserved in the returned model name (e.g., "gemini-2.5-pro(8192)").
|
||||
// However, if the mapping target already contains a suffix, the config suffix
|
||||
// takes priority over the user's suffix.
|
||||
func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
||||
if requestedModel == "" {
|
||||
return ""
|
||||
@@ -49,24 +58,53 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
// Normalize the requested model for lookup
|
||||
normalizedRequest := strings.ToLower(strings.TrimSpace(requestedModel))
|
||||
// Extract thinking suffix from requested model using ParseSuffix
|
||||
requestResult := thinking.ParseSuffix(requestedModel)
|
||||
baseModel := requestResult.ModelName
|
||||
|
||||
// Check for direct mapping
|
||||
targetModel, exists := m.mappings[normalizedRequest]
|
||||
// Normalize the base model for lookup (case-insensitive)
|
||||
normalizedBase := strings.ToLower(strings.TrimSpace(baseModel))
|
||||
|
||||
// Check for direct mapping using base model name
|
||||
targetModel, exists := m.mappings[normalizedBase]
|
||||
if !exists {
|
||||
return ""
|
||||
// Try regex mappings in order using base model only
|
||||
// (suffix is handled separately via ParseSuffix)
|
||||
for _, rm := range m.regexps {
|
||||
if rm.re.MatchString(baseModel) {
|
||||
targetModel = rm.to
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// Verify target model has available providers
|
||||
providers := util.GetProviderName(targetModel)
|
||||
// Check if target model already has a thinking suffix (config priority)
|
||||
targetResult := thinking.ParseSuffix(targetModel)
|
||||
|
||||
// Verify target model has available providers (use base model for lookup)
|
||||
providers := util.GetProviderName(targetResult.ModelName)
|
||||
if len(providers) == 0 {
|
||||
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
|
||||
return ""
|
||||
}
|
||||
|
||||
// Suffix handling: config suffix takes priority, otherwise preserve user suffix
|
||||
if targetResult.HasSuffix {
|
||||
// Config's "to" already contains a suffix - use it as-is (config priority)
|
||||
return targetModel
|
||||
}
|
||||
|
||||
// Preserve user's thinking suffix on the mapped model
|
||||
// (skip empty suffixes to avoid returning "model()")
|
||||
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
|
||||
return targetModel + "(" + requestResult.RawSuffix + ")"
|
||||
}
|
||||
|
||||
// Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go
|
||||
log.Debugf("amp model mapping: resolved %s -> %s", requestedModel, targetModel)
|
||||
return targetModel
|
||||
}
|
||||
|
||||
@@ -78,6 +116,7 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
|
||||
|
||||
// Clear and rebuild mappings
|
||||
m.mappings = make(map[string]string, len(mappings))
|
||||
m.regexps = make([]regexMapping, 0, len(mappings))
|
||||
|
||||
for _, mapping := range mappings {
|
||||
from := strings.TrimSpace(mapping.From)
|
||||
@@ -88,16 +127,30 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Store with normalized lowercase key for case-insensitive lookup
|
||||
normalizedFrom := strings.ToLower(from)
|
||||
m.mappings[normalizedFrom] = to
|
||||
|
||||
log.Debugf("amp model mapping registered: %s -> %s", from, to)
|
||||
if mapping.Regex {
|
||||
// Compile case-insensitive regex; wrap with (?i) to match behavior of exact lookups
|
||||
pattern := "(?i)" + from
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
log.Warnf("amp model mapping: invalid regex %q: %v", from, err)
|
||||
continue
|
||||
}
|
||||
m.regexps = append(m.regexps, regexMapping{re: re, to: to})
|
||||
log.Debugf("amp model regex mapping registered: /%s/ -> %s", from, to)
|
||||
} else {
|
||||
// Store with normalized lowercase key for case-insensitive lookup
|
||||
normalizedFrom := strings.ToLower(from)
|
||||
m.mappings[normalizedFrom] = to
|
||||
log.Debugf("amp model mapping registered: %s -> %s", from, to)
|
||||
}
|
||||
}
|
||||
|
||||
if len(m.mappings) > 0 {
|
||||
log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings))
|
||||
}
|
||||
if n := len(m.regexps); n > 0 {
|
||||
log.Infof("amp model mapping: loaded %d regex mapping(s)", n)
|
||||
}
|
||||
}
|
||||
|
||||
// GetMappings returns a copy of current mappings (for debugging/status).
|
||||
@@ -111,3 +164,8 @@ func (m *DefaultModelMapper) GetMappings() map[string]string {
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type regexMapping struct {
|
||||
re *regexp.Regexp
|
||||
to string
|
||||
}
|
||||
|
||||
@@ -71,6 +71,25 @@ func TestModelMapper_MapModel_WithProvider(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_MapModel_TargetWithThinkingSuffix(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-thinking", "codex", []*registry.ModelInfo{
|
||||
{ID: "gpt-5.2", OwnedBy: "openai", Type: "codex"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-thinking")
|
||||
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "gpt-5.2-alias", To: "gpt-5.2(xhigh)"},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
result := mapper.MapModel("gpt-5.2-alias")
|
||||
if result != "gpt-5.2(xhigh)" {
|
||||
t.Errorf("Expected gpt-5.2(xhigh), got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{
|
||||
@@ -152,9 +171,9 @@ func TestModelMapper_UpdateMappings_SkipsInvalid(t *testing.T) {
|
||||
mapper := NewModelMapper(nil)
|
||||
|
||||
mapper.UpdateMappings([]config.AmpModelMapping{
|
||||
{From: "", To: "model-b"}, // Invalid: empty from
|
||||
{From: "model-a", To: ""}, // Invalid: empty to
|
||||
{From: " ", To: "model-b"}, // Invalid: whitespace from
|
||||
{From: "", To: "model-b"}, // Invalid: empty from
|
||||
{From: "model-a", To: ""}, // Invalid: empty to
|
||||
{From: " ", To: "model-b"}, // Invalid: whitespace from
|
||||
{From: "model-c", To: "model-d"}, // Valid
|
||||
})
|
||||
|
||||
@@ -184,3 +203,173 @@ func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) {
|
||||
t.Error("Original map was modified")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_Regex_MatchBaseWithoutParens(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-regex-1", "gemini", []*registry.ModelInfo{
|
||||
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-regex-1")
|
||||
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "^gpt-5$", To: "gemini-2.5-pro", Regex: true},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
// Incoming model has reasoning suffix, regex matches base, suffix is preserved
|
||||
result := mapper.MapModel("gpt-5(high)")
|
||||
if result != "gemini-2.5-pro(high)" {
|
||||
t.Errorf("Expected gemini-2.5-pro(high), got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_Regex_ExactPrecedence(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-regex-2", "claude", []*registry.ModelInfo{
|
||||
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
|
||||
})
|
||||
reg.RegisterClient("test-client-regex-3", "gemini", []*registry.ModelInfo{
|
||||
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-regex-2")
|
||||
defer reg.UnregisterClient("test-client-regex-3")
|
||||
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "gpt-5", To: "claude-sonnet-4"}, // exact
|
||||
{From: "^gpt-5.*$", To: "gemini-2.5-pro", Regex: true}, // regex
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
// Exact match should win over regex
|
||||
result := mapper.MapModel("gpt-5")
|
||||
if result != "claude-sonnet-4" {
|
||||
t.Errorf("Expected claude-sonnet-4, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_Regex_InvalidPattern_Skipped(t *testing.T) {
|
||||
// Invalid regex should be skipped and not cause panic
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "(", To: "target", Regex: true},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
result := mapper.MapModel("anything")
|
||||
if result != "" {
|
||||
t.Errorf("Expected empty result due to invalid regex, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_Regex_CaseInsensitive(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-regex-4", "claude", []*registry.ModelInfo{
|
||||
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-regex-4")
|
||||
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "^CLAUDE-OPUS-.*$", To: "claude-sonnet-4", Regex: true},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
result := mapper.MapModel("claude-opus-4.5")
|
||||
if result != "claude-sonnet-4" {
|
||||
t.Errorf("Expected claude-sonnet-4, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_SuffixPreservation(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
|
||||
// Register test models
|
||||
reg.RegisterClient("test-client-suffix", "gemini", []*registry.ModelInfo{
|
||||
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
|
||||
})
|
||||
reg.RegisterClient("test-client-suffix-2", "claude", []*registry.ModelInfo{
|
||||
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-suffix")
|
||||
defer reg.UnregisterClient("test-client-suffix-2")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mappings []config.AmpModelMapping
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "numeric suffix preserved",
|
||||
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
|
||||
input: "g25p(8192)",
|
||||
want: "gemini-2.5-pro(8192)",
|
||||
},
|
||||
{
|
||||
name: "level suffix preserved",
|
||||
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
|
||||
input: "g25p(high)",
|
||||
want: "gemini-2.5-pro(high)",
|
||||
},
|
||||
{
|
||||
name: "no suffix unchanged",
|
||||
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
|
||||
input: "g25p",
|
||||
want: "gemini-2.5-pro",
|
||||
},
|
||||
{
|
||||
name: "config suffix takes priority",
|
||||
mappings: []config.AmpModelMapping{{From: "alias", To: "gemini-2.5-pro(medium)"}},
|
||||
input: "alias(high)",
|
||||
want: "gemini-2.5-pro(medium)",
|
||||
},
|
||||
{
|
||||
name: "regex with suffix preserved",
|
||||
mappings: []config.AmpModelMapping{{From: "^g25.*", To: "gemini-2.5-pro", Regex: true}},
|
||||
input: "g25p(8192)",
|
||||
want: "gemini-2.5-pro(8192)",
|
||||
},
|
||||
{
|
||||
name: "auto suffix preserved",
|
||||
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
|
||||
input: "g25p(auto)",
|
||||
want: "gemini-2.5-pro(auto)",
|
||||
},
|
||||
{
|
||||
name: "none suffix preserved",
|
||||
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
|
||||
input: "g25p(none)",
|
||||
want: "gemini-2.5-pro(none)",
|
||||
},
|
||||
{
|
||||
name: "case insensitive base lookup with suffix",
|
||||
mappings: []config.AmpModelMapping{{From: "G25P", To: "gemini-2.5-pro"}},
|
||||
input: "g25p(high)",
|
||||
want: "gemini-2.5-pro(high)",
|
||||
},
|
||||
{
|
||||
name: "empty suffix filtered out",
|
||||
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
|
||||
input: "g25p()",
|
||||
want: "gemini-2.5-pro",
|
||||
},
|
||||
{
|
||||
name: "incomplete suffix treated as no suffix",
|
||||
mappings: []config.AmpModelMapping{{From: "g25p(high", To: "gemini-2.5-pro"}},
|
||||
input: "g25p(high",
|
||||
want: "gemini-2.5-pro",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mapper := NewModelMapper(tt.mappings)
|
||||
got := mapper.MapModel(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("MapModel(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,8 +3,11 @@ package amp
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
@@ -12,9 +15,37 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func removeQueryValuesMatching(req *http.Request, key string, match string) {
|
||||
if req == nil || req.URL == nil || match == "" {
|
||||
return
|
||||
}
|
||||
|
||||
q := req.URL.Query()
|
||||
values, ok := q[key]
|
||||
if !ok || len(values) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
kept := make([]string, 0, len(values))
|
||||
for _, v := range values {
|
||||
if v == match {
|
||||
continue
|
||||
}
|
||||
kept = append(kept, v)
|
||||
}
|
||||
|
||||
if len(kept) == 0 {
|
||||
q.Del(key)
|
||||
} else {
|
||||
q[key] = kept
|
||||
}
|
||||
req.URL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
// readCloser wraps a reader and forwards Close to a separate closer.
|
||||
// Used to restore peeked bytes while preserving upstream body Close behavior.
|
||||
type readCloser struct {
|
||||
@@ -41,6 +72,22 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
originalDirector(req)
|
||||
req.Host = parsed.Host
|
||||
|
||||
// Remove client's Authorization header - it was only used for CLI Proxy API authentication
|
||||
// We will set our own Authorization using the configured upstream-api-key
|
||||
req.Header.Del("Authorization")
|
||||
req.Header.Del("X-Api-Key")
|
||||
req.Header.Del("X-Goog-Api-Key")
|
||||
|
||||
// Remove proxy, client identity, and browser fingerprint headers
|
||||
misc.ScrubProxyAndFingerprintHeaders(req)
|
||||
|
||||
// Remove query-based credentials if they match the authenticated client API key.
|
||||
// This prevents leaking client auth material to the Amp upstream while avoiding
|
||||
// breaking unrelated upstream query parameters.
|
||||
clientKey := getClientAPIKeyFromContext(req.Context())
|
||||
removeQueryValuesMatching(req, "key", clientKey)
|
||||
removeQueryValuesMatching(req, "auth_token", clientKey)
|
||||
|
||||
// Preserve correlation headers for debugging
|
||||
if req.Header.Get("X-Request-ID") == "" {
|
||||
// Could generate one here if needed
|
||||
@@ -50,7 +97,7 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
// Users going through ampcode.com proxy are paying for the service and should get all features
|
||||
// including 1M context window (context-1m-2025-08-07)
|
||||
|
||||
// Inject API key from secret source (precedence: config > env > file)
|
||||
// Inject API key from secret source (only uses upstream-api-key from config)
|
||||
if key, err := secretSource.Get(req.Context()); err == nil && key != "" {
|
||||
req.Header.Set("X-Api-Key", key)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
|
||||
@@ -62,7 +109,15 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
// Modify incoming responses to handle gzip without Content-Encoding
|
||||
// This addresses the same issue as inline handler gzip handling, but at the proxy level
|
||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
||||
// Only process successful responses
|
||||
// Log upstream error responses for diagnostics (502, 503, etc.)
|
||||
// These are NOT proxy connection errors - the upstream responded with an error status
|
||||
if resp.StatusCode >= 500 {
|
||||
log.Errorf("amp upstream responded with error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path)
|
||||
} else if resp.StatusCode >= 400 {
|
||||
log.Warnf("amp upstream responded with client error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path)
|
||||
}
|
||||
|
||||
// Only process successful responses for gzip decompression
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil
|
||||
}
|
||||
@@ -146,9 +201,29 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
return nil
|
||||
}
|
||||
|
||||
// Error handler for proxy failures
|
||||
// Error handler for proxy failures with detailed error classification for diagnostics
|
||||
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
log.Errorf("amp upstream proxy error for %s %s: %v", req.Method, req.URL.Path, err)
|
||||
// Classify the error type for better diagnostics
|
||||
var errType string
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
errType = "timeout"
|
||||
} else if errors.Is(err, context.Canceled) {
|
||||
errType = "canceled"
|
||||
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
errType = "dial_timeout"
|
||||
} else if _, ok := err.(net.Error); ok {
|
||||
errType = "network_error"
|
||||
} else {
|
||||
errType = "connection_error"
|
||||
}
|
||||
|
||||
// Don't log as error for context canceled - it's usually client closing connection
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
} else {
|
||||
log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err)
|
||||
}
|
||||
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusBadGateway)
|
||||
_, _ = rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`))
|
||||
|
||||
@@ -3,11 +3,15 @@ package amp
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
// Helper: compress data with gzip
|
||||
@@ -306,6 +310,159 @@ func TestReverseProxy_EmptySecret(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery(t *testing.T) {
|
||||
type captured struct {
|
||||
headers http.Header
|
||||
query string
|
||||
}
|
||||
got := make(chan captured, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
got <- captured{headers: r.Header.Clone(), query: r.URL.RawQuery}
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`ok`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("upstream"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Simulate clientAPIKeyMiddleware injection (per-request)
|
||||
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "client-key")
|
||||
proxy.ServeHTTP(w, r.WithContext(ctx))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, srv.URL+"/test?key=client-key&key=keep&auth_token=client-key&foo=bar", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer client-key")
|
||||
req.Header.Set("X-Api-Key", "client-key")
|
||||
req.Header.Set("X-Goog-Api-Key", "client-key")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
c := <-got
|
||||
|
||||
// These are client-provided credentials and must not reach the upstream.
|
||||
if v := c.headers.Get("X-Goog-Api-Key"); v != "" {
|
||||
t.Fatalf("X-Goog-Api-Key should be stripped, got: %q", v)
|
||||
}
|
||||
|
||||
// We inject upstream Authorization/X-Api-Key, so the client auth must not survive.
|
||||
if v := c.headers.Get("Authorization"); v != "Bearer upstream" {
|
||||
t.Fatalf("Authorization should be upstream-injected, got: %q", v)
|
||||
}
|
||||
if v := c.headers.Get("X-Api-Key"); v != "upstream" {
|
||||
t.Fatalf("X-Api-Key should be upstream-injected, got: %q", v)
|
||||
}
|
||||
|
||||
// Query-based credentials should be stripped only when they match the authenticated client key.
|
||||
// Should keep unrelated values and parameters.
|
||||
if strings.Contains(c.query, "auth_token=client-key") || strings.Contains(c.query, "key=client-key") {
|
||||
t.Fatalf("query credentials should be stripped, got raw query: %q", c.query)
|
||||
}
|
||||
if !strings.Contains(c.query, "key=keep") || !strings.Contains(c.query, "foo=bar") {
|
||||
t.Fatalf("expected query to keep non-credential params, got raw query: %q", c.query)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_InjectsMappedSecret_FromRequestContext(t *testing.T) {
|
||||
gotHeaders := make(chan http.Header, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotHeaders <- r.Header.Clone()
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`ok`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
defaultSource := NewStaticSecretSource("default")
|
||||
mapped := NewMappedSecretSource(defaultSource)
|
||||
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||
{
|
||||
UpstreamAPIKey: "u1",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
})
|
||||
|
||||
proxy, err := createReverseProxy(upstream.URL, mapped)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Simulate clientAPIKeyMiddleware injection (per-request)
|
||||
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k1")
|
||||
proxy.ServeHTTP(w, r.WithContext(ctx))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
res, err := http.Get(srv.URL + "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
hdr := <-gotHeaders
|
||||
if hdr.Get("X-Api-Key") != "u1" {
|
||||
t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
|
||||
}
|
||||
if hdr.Get("Authorization") != "Bearer u1" {
|
||||
t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_MappedSecret_FallsBackToDefault(t *testing.T) {
|
||||
gotHeaders := make(chan http.Header, 1)
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotHeaders <- r.Header.Clone()
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(`ok`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
defaultSource := NewStaticSecretSource("default")
|
||||
mapped := NewMappedSecretSource(defaultSource)
|
||||
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||
{
|
||||
UpstreamAPIKey: "u1",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
})
|
||||
|
||||
proxy, err := createReverseProxy(upstream.URL, mapped)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k2")
|
||||
proxy.ServeHTTP(w, r.WithContext(ctx))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
res, err := http.Get(srv.URL + "/test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
hdr := <-gotHeaders
|
||||
if hdr.Get("X-Api-Key") != "default" {
|
||||
t.Fatalf("X-Api-Key fallback missing or wrong, got: %q", hdr.Get("X-Api-Key"))
|
||||
}
|
||||
if hdr.Get("Authorization") != "Bearer default" {
|
||||
t.Fatalf("Authorization fallback missing or wrong, got: %q", hdr.Get("Authorization"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_ErrorHandler(t *testing.T) {
|
||||
// Point proxy to a non-routable address to trigger error
|
||||
proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource(""))
|
||||
@@ -336,6 +493,30 @@ func TestReverseProxy_ErrorHandler(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_ErrorHandler_ContextCanceled(t *testing.T) {
|
||||
// Test that context.Canceled errors return 499 without generic error response
|
||||
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource(""))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a canceled context to trigger the cancellation path
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil).WithContext(ctx)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Directly invoke the ErrorHandler with context.Canceled
|
||||
proxy.ErrorHandler(rr, req, context.Canceled)
|
||||
|
||||
// Body should be empty for canceled requests (no JSON error response)
|
||||
body := rr.Body.Bytes()
|
||||
if len(body) > 0 {
|
||||
t.Fatalf("expected empty body for canceled context, got: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_FullRoundTrip_Gzip(t *testing.T) {
|
||||
// Upstream returns gzipped JSON without Content-Encoding header
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
183
internal/api/modules/amp/response_rewriter.go
Normal file
183
internal/api/modules/amp/response_rewriter.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body
|
||||
// It's used to rewrite model names in responses when model mapping is used
|
||||
type ResponseRewriter struct {
|
||||
gin.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
originalModel string
|
||||
isStreaming bool
|
||||
}
|
||||
|
||||
// NewResponseRewriter creates a new response rewriter for model name substitution
|
||||
func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter {
|
||||
return &ResponseRewriter{
|
||||
ResponseWriter: w,
|
||||
body: &bytes.Buffer{},
|
||||
originalModel: originalModel,
|
||||
}
|
||||
}
|
||||
|
||||
const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap
|
||||
|
||||
func looksLikeSSEChunk(data []byte) bool {
|
||||
// Fallback detection: some upstreams may omit/lie about Content-Type, causing SSE to be buffered.
|
||||
// Heuristics are intentionally simple and cheap.
|
||||
return bytes.Contains(data, []byte("data:")) ||
|
||||
bytes.Contains(data, []byte("event:")) ||
|
||||
bytes.Contains(data, []byte("message_start")) ||
|
||||
bytes.Contains(data, []byte("message_delta")) ||
|
||||
bytes.Contains(data, []byte("content_block_start")) ||
|
||||
bytes.Contains(data, []byte("content_block_delta")) ||
|
||||
bytes.Contains(data, []byte("content_block_stop")) ||
|
||||
bytes.Contains(data, []byte("\n\n"))
|
||||
}
|
||||
|
||||
func (rw *ResponseRewriter) enableStreaming(reason string) error {
|
||||
if rw.isStreaming {
|
||||
return nil
|
||||
}
|
||||
rw.isStreaming = true
|
||||
|
||||
// Flush any previously buffered data to avoid reordering or data loss.
|
||||
if rw.body != nil && rw.body.Len() > 0 {
|
||||
buf := rw.body.Bytes()
|
||||
// Copy before Reset() to keep bytes stable.
|
||||
toFlush := make([]byte, len(buf))
|
||||
copy(toFlush, buf)
|
||||
rw.body.Reset()
|
||||
|
||||
if _, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(toFlush)); err != nil {
|
||||
return err
|
||||
}
|
||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("amp response rewriter: switched to streaming (%s)", reason)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write intercepts response writes and buffers them for model name replacement
|
||||
func (rw *ResponseRewriter) Write(data []byte) (int, error) {
|
||||
// Detect streaming on first write (header-based)
|
||||
if !rw.isStreaming && rw.body.Len() == 0 {
|
||||
contentType := rw.Header().Get("Content-Type")
|
||||
rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
|
||||
strings.Contains(contentType, "stream")
|
||||
}
|
||||
|
||||
if !rw.isStreaming {
|
||||
// Content-based fallback: detect SSE-like chunks even if Content-Type is missing/wrong.
|
||||
if looksLikeSSEChunk(data) {
|
||||
if err := rw.enableStreaming("sse heuristic"); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else if rw.body.Len()+len(data) > maxBufferedResponseBytes {
|
||||
// Safety cap: avoid unbounded buffering on large responses.
|
||||
log.Warnf("amp response rewriter: buffer exceeded %d bytes, switching to streaming", maxBufferedResponseBytes)
|
||||
if err := rw.enableStreaming("buffer limit"); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rw.isStreaming {
|
||||
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
||||
if err == nil {
|
||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
return rw.body.Write(data)
|
||||
}
|
||||
|
||||
// Flush writes the buffered response with model names rewritten
|
||||
func (rw *ResponseRewriter) Flush() {
|
||||
if rw.isStreaming {
|
||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
if rw.body.Len() > 0 {
|
||||
if _, err := rw.ResponseWriter.Write(rw.rewriteModelInResponse(rw.body.Bytes())); err != nil {
|
||||
log.Warnf("amp response rewriter: failed to write rewritten response: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// modelFieldPaths lists all JSON paths where model name may appear
|
||||
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
|
||||
|
||||
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
|
||||
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility
|
||||
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
||||
// 1. Amp Compatibility: Suppress thinking blocks if tool use is detected
|
||||
// The Amp client struggles when both thinking and tool_use blocks are present
|
||||
if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
|
||||
filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
|
||||
if filtered.Exists() {
|
||||
originalCount := gjson.GetBytes(data, "content.#").Int()
|
||||
filteredCount := filtered.Get("#").Int()
|
||||
|
||||
if originalCount > filteredCount {
|
||||
var err error
|
||||
data, err = sjson.SetBytes(data, "content", filtered.Value())
|
||||
if err != nil {
|
||||
log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
|
||||
} else {
|
||||
log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount)
|
||||
// Log the result for verification
|
||||
log.Debugf("Amp ResponseRewriter: Resulting content: %s", gjson.GetBytes(data, "content").String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rw.originalModel == "" {
|
||||
return data
|
||||
}
|
||||
for _, path := range modelFieldPaths {
|
||||
if gjson.GetBytes(data, path).Exists() {
|
||||
data, _ = sjson.SetBytes(data, path, rw.originalModel)
|
||||
}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// rewriteStreamChunk rewrites model names in SSE stream chunks
|
||||
func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
|
||||
if rw.originalModel == "" {
|
||||
return chunk
|
||||
}
|
||||
|
||||
// SSE format: "data: {json}\n\n"
|
||||
lines := bytes.Split(chunk, []byte("\n"))
|
||||
for i, line := range lines {
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
jsonData := bytes.TrimPrefix(line, []byte("data: "))
|
||||
if len(jsonData) > 0 && jsonData[0] == '{' {
|
||||
// Rewrite JSON in the data line
|
||||
rewritten := rw.rewriteModelInResponse(jsonData)
|
||||
lines[i] = append([]byte("data: "), rewritten...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bytes.Join(lines, []byte("\n"))
|
||||
}
|
||||
110
internal/api/modules/amp/response_rewriter_test.go
Normal file
110
internal/api/modules/amp/response_rewriter_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRewriteModelInResponse_TopLevel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
input := []byte(`{"id":"resp_1","model":"gpt-5.3-codex","output":[]}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
expected := `{"id":"resp_1","model":"gpt-5.2-codex","output":[]}`
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteModelInResponse_ResponseModel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
input := []byte(`{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.3-codex","status":"completed"}}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
expected := `{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.2-codex","status":"completed"}}`
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteModelInResponse_ResponseCreated(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
input := []byte(`{"type":"response.created","response":{"id":"resp_1","model":"gpt-5.3-codex","status":"in_progress"}}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
expected := `{"type":"response.created","response":{"id":"resp_1","model":"gpt-5.2-codex","status":"in_progress"}}`
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteModelInResponse_NoModelField(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
input := []byte(`{"type":"response.output_item.added","item":{"id":"item_1","type":"message"}}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
if string(result) != string(input) {
|
||||
t.Errorf("expected no modification, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteModelInResponse_EmptyOriginalModel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: ""}
|
||||
|
||||
input := []byte(`{"model":"gpt-5.3-codex"}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
if string(result) != string(input) {
|
||||
t.Errorf("expected no modification when originalModel is empty, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteStreamChunk_SSEWithResponseModel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
chunk := []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-5.3-codex\",\"status\":\"completed\"}}\n\n")
|
||||
result := rw.rewriteStreamChunk(chunk)
|
||||
|
||||
expected := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-5.2-codex\",\"status\":\"completed\"}}\n\n"
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteStreamChunk_MultipleEvents(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
chunk := []byte("data: {\"type\":\"response.created\",\"response\":{\"model\":\"gpt-5.3-codex\"}}\n\ndata: {\"type\":\"response.output_item.added\",\"item\":{\"id\":\"item_1\"}}\n\n")
|
||||
result := rw.rewriteStreamChunk(chunk)
|
||||
|
||||
if string(result) == string(chunk) {
|
||||
t.Error("expected response.model to be rewritten in SSE stream")
|
||||
}
|
||||
if !contains(result, []byte(`"model":"gpt-5.2-codex"`)) {
|
||||
t.Errorf("expected rewritten model in output, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteStreamChunk_MessageModel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "claude-opus-4.5"}
|
||||
|
||||
chunk := []byte("data: {\"message\":{\"model\":\"claude-sonnet-4\",\"role\":\"assistant\"}}\n\n")
|
||||
result := rw.rewriteStreamChunk(chunk)
|
||||
|
||||
expected := "data: {\"message\":{\"model\":\"claude-opus-4.5\",\"role\":\"assistant\"}}\n\n"
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func contains(data, substr []byte) bool {
|
||||
for i := 0; i <= len(data)-len(substr); i++ {
|
||||
if string(data[i:i+len(substr)]) == string(substr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,12 +1,15 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
||||
@@ -14,15 +17,47 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// localhostOnlyMiddleware restricts access to localhost (127.0.0.1, ::1) only.
|
||||
// Returns 403 Forbidden for non-localhost clients.
|
||||
//
|
||||
// Security: Uses RemoteAddr (actual TCP connection) instead of ClientIP() to prevent
|
||||
// header spoofing attacks via X-Forwarded-For or similar headers. This means the
|
||||
// middleware will not work correctly behind reverse proxies - users deploying behind
|
||||
// nginx/Cloudflare should disable this feature and use firewall rules instead.
|
||||
func localhostOnlyMiddleware() gin.HandlerFunc {
|
||||
// clientAPIKeyContextKey is the context key used to pass the client API key
|
||||
// from gin.Context to the request context for SecretSource lookup.
|
||||
type clientAPIKeyContextKey struct{}
|
||||
|
||||
// clientAPIKeyMiddleware injects the authenticated client API key from gin.Context["apiKey"]
|
||||
// into the request context so that SecretSource can look it up for per-client upstream routing.
|
||||
func clientAPIKeyMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Extract the client API key from gin context (set by AuthMiddleware)
|
||||
if apiKey, exists := c.Get("apiKey"); exists {
|
||||
if keyStr, ok := apiKey.(string); ok && keyStr != "" {
|
||||
// Inject into request context for SecretSource.Get(ctx) to read
|
||||
ctx := context.WithValue(c.Request.Context(), clientAPIKeyContextKey{}, keyStr)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// getClientAPIKeyFromContext retrieves the client API key from request context.
|
||||
// Returns empty string if not present.
|
||||
func getClientAPIKeyFromContext(ctx context.Context) string {
|
||||
if val := ctx.Value(clientAPIKeyContextKey{}); val != nil {
|
||||
if keyStr, ok := val.(string); ok {
|
||||
return keyStr
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// localhostOnlyMiddleware returns a middleware that dynamically checks the module's
|
||||
// localhost restriction setting. This allows hot-reload of the restriction without restarting.
|
||||
func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Check current setting (hot-reloadable)
|
||||
if !m.IsRestrictedToLocalhost() {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Use actual TCP connection address (RemoteAddr) to prevent header spoofing
|
||||
// This cannot be forged by X-Forwarded-For or other client-controlled headers
|
||||
remoteAddr := c.Request.RemoteAddr
|
||||
@@ -77,21 +112,77 @@ func noCORSMiddleware() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// managementAvailabilityMiddleware short-circuits management routes when the upstream
|
||||
// proxy is disabled, preventing noisy localhost warnings and accidental exposure.
|
||||
func (m *AmpModule) managementAvailabilityMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if m.getProxy() == nil {
|
||||
logging.SkipGinRequestLogging(c)
|
||||
c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{
|
||||
"error": "amp upstream proxy not available",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// wrapManagementAuth skips auth for selected management paths while keeping authentication elsewhere.
|
||||
func wrapManagementAuth(auth gin.HandlerFunc, prefixes ...string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
for _, prefix := range prefixes {
|
||||
if strings.HasPrefix(path, prefix) && (len(path) == len(prefix) || path[len(prefix)] == '/') {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
auth(c)
|
||||
}
|
||||
}
|
||||
|
||||
// registerManagementRoutes registers Amp management proxy routes
|
||||
// These routes proxy through to the Amp control plane for OAuth, user management, etc.
|
||||
// If restrictToLocalhost is true, routes will only accept connections from 127.0.0.1/::1.
|
||||
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, proxyHandler gin.HandlerFunc, restrictToLocalhost bool) {
|
||||
// Uses dynamic middleware and proxy getter for hot-reload support.
|
||||
// The auth middleware validates Authorization header against configured API keys.
|
||||
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, auth gin.HandlerFunc) {
|
||||
ampAPI := engine.Group("/api")
|
||||
|
||||
// Always disable CORS for management routes to prevent browser-based attacks
|
||||
ampAPI.Use(noCORSMiddleware())
|
||||
ampAPI.Use(m.managementAvailabilityMiddleware(), noCORSMiddleware())
|
||||
|
||||
// Apply localhost-only restriction if configured
|
||||
if restrictToLocalhost {
|
||||
ampAPI.Use(localhostOnlyMiddleware())
|
||||
log.Info("amp management routes restricted to localhost only (CORS disabled)")
|
||||
} else {
|
||||
log.Warn("amp management routes are NOT restricted to localhost - this is insecure!")
|
||||
// Apply dynamic localhost-only restriction (hot-reloadable via m.IsRestrictedToLocalhost())
|
||||
ampAPI.Use(m.localhostOnlyMiddleware())
|
||||
|
||||
// Apply authentication middleware - requires valid API key in Authorization header
|
||||
var authWithBypass gin.HandlerFunc
|
||||
if auth != nil {
|
||||
ampAPI.Use(auth)
|
||||
authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings")
|
||||
}
|
||||
|
||||
// Inject client API key into request context for per-client upstream routing
|
||||
ampAPI.Use(clientAPIKeyMiddleware())
|
||||
|
||||
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
|
||||
proxyHandler := func(c *gin.Context) {
|
||||
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) {
|
||||
// Upstream already wrote the status (often 404) before the client/stream ended.
|
||||
return
|
||||
}
|
||||
panic(rec)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := m.getProxy()
|
||||
if proxy == nil {
|
||||
c.JSON(503, gin.H{"error": "amp upstream proxy not available"})
|
||||
return
|
||||
}
|
||||
proxy.ServeHTTP(c.Writer, c.Request)
|
||||
}
|
||||
|
||||
// Management routes - these are proxied directly to Amp upstream
|
||||
@@ -114,12 +205,22 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
ampAPI.Any("/tab/*path", proxyHandler)
|
||||
|
||||
// Root-level routes that AMP CLI expects without /api prefix
|
||||
// These need the same security middleware as the /api/* routes
|
||||
rootMiddleware := []gin.HandlerFunc{noCORSMiddleware()}
|
||||
if restrictToLocalhost {
|
||||
rootMiddleware = append(rootMiddleware, localhostOnlyMiddleware())
|
||||
// These need the same security middleware as the /api/* routes (dynamic for hot-reload)
|
||||
rootMiddleware := []gin.HandlerFunc{m.managementAvailabilityMiddleware(), noCORSMiddleware(), m.localhostOnlyMiddleware()}
|
||||
if authWithBypass != nil {
|
||||
rootMiddleware = append(rootMiddleware, authWithBypass)
|
||||
}
|
||||
// Add clientAPIKeyMiddleware after auth for per-client upstream routing
|
||||
rootMiddleware = append(rootMiddleware, clientAPIKeyMiddleware())
|
||||
engine.GET("/threads", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/docs", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/docs/*path", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/settings", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/settings/*path", append(rootMiddleware, proxyHandler)...)
|
||||
|
||||
engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/news.rss", append(rootMiddleware, proxyHandler)...)
|
||||
|
||||
// Root-level auth routes for CLI login flow
|
||||
// Amp uses multiple auth routes: /auth/cli-login, /auth/callback, /auth/sign-in, /auth/logout
|
||||
@@ -132,30 +233,22 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
// We bridge these to our standard Gemini handler to enable local OAuth.
|
||||
// If no local OAuth is available, falls back to ampcode.com proxy.
|
||||
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
|
||||
geminiBridge := createGeminiBridgeHandler(geminiHandlers)
|
||||
geminiV1Beta1Fallback := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||
return m.proxy
|
||||
})
|
||||
geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler)
|
||||
geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
||||
return m.getProxy()
|
||||
}, m.modelMapper, m.forceModelMappings)
|
||||
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
|
||||
|
||||
// Route POST model calls through Gemini bridge when a local provider exists, otherwise proxy.
|
||||
// Route POST model calls through Gemini bridge with FallbackHandler.
|
||||
// FallbackHandler checks provider -> mapping -> proxy fallback automatically.
|
||||
// All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior.
|
||||
ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) {
|
||||
if c.Request.Method == "POST" {
|
||||
// Attempt to extract the model name from the AMP-style path
|
||||
if path := c.Param("path"); strings.Contains(path, "/models/") {
|
||||
modelPart := path[strings.Index(path, "/models/")+len("/models/"):]
|
||||
if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 {
|
||||
modelPart = modelPart[:colonIdx]
|
||||
}
|
||||
if modelPart != "" {
|
||||
normalized, _ := util.NormalizeGeminiThinkingModel(modelPart)
|
||||
// Only handle locally when we have a provider; otherwise fall back to proxy
|
||||
if providers := util.GetProviderName(normalized); len(providers) > 0 {
|
||||
geminiV1Beta1Handler(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
// POST with /models/ path -> use Gemini bridge with fallback handler
|
||||
// FallbackHandler will check provider/mapping and proxy if needed
|
||||
geminiV1Beta1Handler(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
// Non-POST or no local provider available -> proxy upstream
|
||||
@@ -177,17 +270,19 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
||||
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler)
|
||||
|
||||
// Create fallback handler wrapper that forwards to ampcode.com when provider not found
|
||||
// Uses lazy evaluation to access proxy (which is created after routes are registered)
|
||||
// Uses m.getProxy() for hot-reload support (proxy can be updated at runtime)
|
||||
// Also includes model mapping support for routing unavailable models to alternatives
|
||||
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
||||
return m.proxy
|
||||
}, m.modelMapper)
|
||||
return m.getProxy()
|
||||
}, m.modelMapper, m.forceModelMappings)
|
||||
|
||||
// Provider-specific routes under /api/provider/:provider
|
||||
ampProviders := engine.Group("/api/provider")
|
||||
if auth != nil {
|
||||
ampProviders.Use(auth)
|
||||
}
|
||||
// Inject client API key into request context for per-client upstream routing
|
||||
ampProviders.Use(clientAPIKeyMiddleware())
|
||||
|
||||
provider := ampProviders.Group("/:provider")
|
||||
|
||||
@@ -233,7 +328,7 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
||||
v1betaAmp := provider.Group("/v1beta")
|
||||
{
|
||||
v1betaAmp.GET("/models", geminiHandlers.GeminiModels)
|
||||
v1betaAmp.POST("/models/:action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
|
||||
v1betaAmp.GET("/models/:action", geminiHandlers.GeminiGetHandler)
|
||||
v1betaAmp.POST("/models/*action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
|
||||
v1betaAmp.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,16 +13,28 @@ func TestRegisterManagementRoutes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
// Spy to track if proxy handler was called
|
||||
proxyCalled := false
|
||||
proxyHandler := func(c *gin.Context) {
|
||||
proxyCalled = true
|
||||
c.String(200, "proxied")
|
||||
// Create module with proxy for testing
|
||||
m := &AmpModule{
|
||||
restrictToLocalhost: false, // disable localhost restriction for tests
|
||||
}
|
||||
|
||||
m := &AmpModule{}
|
||||
// Create a mock proxy that tracks calls
|
||||
proxyCalled := false
|
||||
mockProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxyCalled = true
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte("proxied"))
|
||||
}))
|
||||
defer mockProxy.Close()
|
||||
|
||||
// Create real proxy to mock server
|
||||
proxy, _ := createReverseProxy(mockProxy.URL, NewStaticSecretSource(""))
|
||||
m.setProxy(proxy)
|
||||
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
m.registerManagementRoutes(r, base, proxyHandler, false) // false = don't restrict to localhost in tests
|
||||
m.registerManagementRoutes(r, base, nil)
|
||||
srv := httptest.NewServer(r)
|
||||
defer srv.Close()
|
||||
|
||||
managementPaths := []struct {
|
||||
path string
|
||||
@@ -37,13 +49,14 @@ func TestRegisterManagementRoutes(t *testing.T) {
|
||||
{"/api/meta", http.MethodGet},
|
||||
{"/api/telemetry", http.MethodGet},
|
||||
{"/api/threads", http.MethodGet},
|
||||
{"/threads/", http.MethodGet},
|
||||
{"/threads.rss", http.MethodGet}, // Root-level route (no /api prefix)
|
||||
{"/api/otel", http.MethodGet},
|
||||
{"/api/tab", http.MethodGet},
|
||||
{"/api/tab/some/path", http.MethodGet},
|
||||
{"/auth", http.MethodGet}, // Root-level auth route
|
||||
{"/auth/cli-login", http.MethodGet}, // CLI login flow
|
||||
{"/auth/callback", http.MethodGet}, // OAuth callback
|
||||
{"/auth", http.MethodGet}, // Root-level auth route
|
||||
{"/auth/cli-login", http.MethodGet}, // CLI login flow
|
||||
{"/auth/callback", http.MethodGet}, // OAuth callback
|
||||
// Google v1beta1 bridge should still proxy non-model requests (GET) and allow POST
|
||||
{"/api/provider/google/v1beta1/models", http.MethodGet},
|
||||
{"/api/provider/google/v1beta1/models", http.MethodPost},
|
||||
@@ -52,11 +65,17 @@ func TestRegisterManagementRoutes(t *testing.T) {
|
||||
for _, path := range managementPaths {
|
||||
t.Run(path.path, func(t *testing.T) {
|
||||
proxyCalled = false
|
||||
req := httptest.NewRequest(path.method, path.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
req, err := http.NewRequest(path.method, srv.URL+path.path, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build request: %v", err)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if w.Code == http.StatusNotFound {
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
t.Fatalf("route %s not registered", path.path)
|
||||
}
|
||||
if !proxyCalled {
|
||||
@@ -231,8 +250,13 @@ func TestLocalhostOnlyMiddleware_PreventsSpoofing(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
// Apply localhost-only middleware
|
||||
r.Use(localhostOnlyMiddleware())
|
||||
// Create module with localhost restriction enabled
|
||||
m := &AmpModule{
|
||||
restrictToLocalhost: true,
|
||||
}
|
||||
|
||||
// Apply dynamic localhost-only middleware
|
||||
r.Use(m.localhostOnlyMiddleware())
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
@@ -305,3 +329,53 @@ func TestLocalhostOnlyMiddleware_PreventsSpoofing(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalhostOnlyMiddleware_HotReload(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
// Create module with localhost restriction initially enabled
|
||||
m := &AmpModule{
|
||||
restrictToLocalhost: true,
|
||||
}
|
||||
|
||||
// Apply dynamic localhost-only middleware
|
||||
r.Use(m.localhostOnlyMiddleware())
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
// Test 1: Remote IP should be blocked when restriction is enabled
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.100:12345"
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("Expected 403 when restriction enabled, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Test 2: Hot-reload - disable restriction
|
||||
m.setRestrictToLocalhost(false)
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.100:12345"
|
||||
w = httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected 200 after disabling restriction, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Test 3: Hot-reload - re-enable restriction
|
||||
m.setRestrictToLocalhost(true)
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.100:12345"
|
||||
w = httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("Expected 403 after re-enabling restriction, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,9 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// SecretSource provides Amp API keys with configurable precedence and caching
|
||||
@@ -139,6 +142,17 @@ func (s *MultiSourceSecret) InvalidateCache() {
|
||||
s.cache = nil
|
||||
}
|
||||
|
||||
// UpdateExplicitKey refreshes the config-provided key and clears cache.
|
||||
func (s *MultiSourceSecret) UpdateExplicitKey(key string) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.explicitKey = strings.TrimSpace(key)
|
||||
s.cache = nil
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// StaticSecretSource returns a fixed API key (for testing)
|
||||
type StaticSecretSource struct {
|
||||
key string
|
||||
@@ -153,3 +167,82 @@ func NewStaticSecretSource(key string) *StaticSecretSource {
|
||||
func (s *StaticSecretSource) Get(ctx context.Context) (string, error) {
|
||||
return s.key, nil
|
||||
}
|
||||
|
||||
// MappedSecretSource wraps a default SecretSource and adds per-client API key mapping.
|
||||
// When a request context contains a client API key that matches a configured mapping,
|
||||
// the corresponding upstream key is returned. Otherwise, falls back to the default source.
|
||||
type MappedSecretSource struct {
|
||||
defaultSource SecretSource
|
||||
mu sync.RWMutex
|
||||
lookup map[string]string // clientKey -> upstreamKey
|
||||
}
|
||||
|
||||
// NewMappedSecretSource creates a MappedSecretSource wrapping the given default source.
|
||||
func NewMappedSecretSource(defaultSource SecretSource) *MappedSecretSource {
|
||||
return &MappedSecretSource{
|
||||
defaultSource: defaultSource,
|
||||
lookup: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves the Amp API key, checking per-client mappings first.
|
||||
// If the request context contains a client API key that matches a configured mapping,
|
||||
// returns the corresponding upstream key. Otherwise, falls back to the default source.
|
||||
func (s *MappedSecretSource) Get(ctx context.Context) (string, error) {
|
||||
// Try to get client API key from request context
|
||||
clientKey := getClientAPIKeyFromContext(ctx)
|
||||
if clientKey != "" {
|
||||
s.mu.RLock()
|
||||
if upstreamKey, ok := s.lookup[clientKey]; ok && upstreamKey != "" {
|
||||
s.mu.RUnlock()
|
||||
return upstreamKey, nil
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
}
|
||||
|
||||
// Fall back to default source
|
||||
return s.defaultSource.Get(ctx)
|
||||
}
|
||||
|
||||
// UpdateMappings rebuilds the client-to-upstream key mapping from configuration entries.
|
||||
// If the same client key appears in multiple entries, logs a warning and uses the first one.
|
||||
func (s *MappedSecretSource) UpdateMappings(entries []config.AmpUpstreamAPIKeyEntry) {
|
||||
newLookup := make(map[string]string)
|
||||
|
||||
for _, entry := range entries {
|
||||
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
|
||||
if upstreamKey == "" {
|
||||
continue
|
||||
}
|
||||
for _, clientKey := range entry.APIKeys {
|
||||
trimmedKey := strings.TrimSpace(clientKey)
|
||||
if trimmedKey == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := newLookup[trimmedKey]; exists {
|
||||
// Log warning for duplicate client key, first one wins
|
||||
log.Warnf("amp upstream-api-keys: client API key appears in multiple entries; using first mapping.")
|
||||
continue
|
||||
}
|
||||
newLookup[trimmedKey] = upstreamKey
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.lookup = newLookup
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// UpdateDefaultExplicitKey updates the explicit key on the underlying MultiSourceSecret (if applicable).
|
||||
func (s *MappedSecretSource) UpdateDefaultExplicitKey(key string) {
|
||||
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
|
||||
ms.UpdateExplicitKey(key)
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateCache invalidates cache on the underlying MultiSourceSecret (if applicable).
|
||||
func (s *MappedSecretSource) InvalidateCache() {
|
||||
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
|
||||
ms.InvalidateCache()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,10 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/sirupsen/logrus/hooks/test"
|
||||
)
|
||||
|
||||
func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) {
|
||||
@@ -278,3 +282,85 @@ func TestMultiSourceSecret_CacheEmptyResult(t *testing.T) {
|
||||
t.Fatalf("after cache expiry, expected new-value, got %q", got3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMappedSecretSource_UsesMappingFromContext(t *testing.T) {
|
||||
defaultSource := NewStaticSecretSource("default")
|
||||
s := NewMappedSecretSource(defaultSource)
|
||||
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||
{
|
||||
UpstreamAPIKey: "u1",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
})
|
||||
|
||||
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
|
||||
got, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "u1" {
|
||||
t.Fatalf("want u1, got %q", got)
|
||||
}
|
||||
|
||||
ctx = context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k2")
|
||||
got, err = s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "default" {
|
||||
t.Fatalf("want default fallback, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMappedSecretSource_DuplicateClientKey_FirstWins(t *testing.T) {
|
||||
defaultSource := NewStaticSecretSource("default")
|
||||
s := NewMappedSecretSource(defaultSource)
|
||||
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||
{
|
||||
UpstreamAPIKey: "u1",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
{
|
||||
UpstreamAPIKey: "u2",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
})
|
||||
|
||||
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
|
||||
got, err := s.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "u1" {
|
||||
t.Fatalf("want u1 (first wins), got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMappedSecretSource_DuplicateClientKey_LogsWarning(t *testing.T) {
|
||||
hook := test.NewLocal(log.StandardLogger())
|
||||
defer hook.Reset()
|
||||
|
||||
defaultSource := NewStaticSecretSource("default")
|
||||
s := NewMappedSecretSource(defaultSource)
|
||||
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||
{
|
||||
UpstreamAPIKey: "u1",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
{
|
||||
UpstreamAPIKey: "u2",
|
||||
APIKeys: []string{"k1"},
|
||||
},
|
||||
})
|
||||
|
||||
foundWarning := false
|
||||
for _, entry := range hook.AllEntries() {
|
||||
if entry.Level == log.WarnLevel && entry.Message == "amp upstream-api-keys: client API key appears in multiple entries; using first mapping." {
|
||||
foundWarning = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundWarning {
|
||||
t.Fatal("expected warning log for duplicate client key, but none was found")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -23,6 +24,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/middleware"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
|
||||
ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
||||
@@ -33,6 +35,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gopkg.in/yaml.v3"
|
||||
@@ -49,6 +52,7 @@ type serverOptionConfig struct {
|
||||
keepAliveEnabled bool
|
||||
keepAliveTimeout time.Duration
|
||||
keepAliveOnTimeout func()
|
||||
postAuthHook auth.PostAuthHook
|
||||
}
|
||||
|
||||
// ServerOption customises HTTP server construction.
|
||||
@@ -56,10 +60,8 @@ type ServerOption func(*serverOptionConfig)
|
||||
|
||||
func defaultRequestLoggerFactory(cfg *config.Config, configPath string) logging.RequestLogger {
|
||||
configDir := filepath.Dir(configPath)
|
||||
if base := util.WritablePath(); base != "" {
|
||||
return logging.NewFileRequestLogger(cfg.RequestLog, filepath.Join(base, "logs"), configDir)
|
||||
}
|
||||
return logging.NewFileRequestLogger(cfg.RequestLog, "logs", configDir)
|
||||
logsDir := logging.ResolveLogDirectory(cfg)
|
||||
return logging.NewFileRequestLogger(cfg.RequestLog, logsDir, configDir, cfg.ErrorLogsMaxFiles)
|
||||
}
|
||||
|
||||
// WithMiddleware appends additional Gin middleware during server construction.
|
||||
@@ -109,6 +111,13 @@ func WithRequestLoggerFactory(factory func(*config.Config, string) logging.Reque
|
||||
}
|
||||
}
|
||||
|
||||
// WithPostAuthHook registers a hook to be called after auth record creation.
|
||||
func WithPostAuthHook(hook auth.PostAuthHook) ServerOption {
|
||||
return func(cfg *serverOptionConfig) {
|
||||
cfg.postAuthHook = hook
|
||||
}
|
||||
}
|
||||
|
||||
// Server represents the main API server.
|
||||
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
|
||||
type Server struct {
|
||||
@@ -209,13 +218,15 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
// Resolve logs directory relative to the configuration file directory.
|
||||
var requestLogger logging.RequestLogger
|
||||
var toggle func(bool)
|
||||
if optionState.requestLoggerFactory != nil {
|
||||
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
|
||||
}
|
||||
if requestLogger != nil {
|
||||
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
|
||||
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
|
||||
toggle = setter.SetEnabled
|
||||
if !cfg.CommercialMode {
|
||||
if optionState.requestLoggerFactory != nil {
|
||||
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
|
||||
}
|
||||
if requestLogger != nil {
|
||||
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
|
||||
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
|
||||
toggle = setter.SetEnabled
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -230,13 +241,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
envManagementSecret := envAdminPasswordSet && envAdminPassword != ""
|
||||
|
||||
// Create server instance
|
||||
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
|
||||
for _, p := range cfg.OpenAICompatibility {
|
||||
providerNames = append(providerNames, p.Name)
|
||||
}
|
||||
s := &Server{
|
||||
engine: engine,
|
||||
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager, providerNames),
|
||||
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager),
|
||||
cfg: cfg,
|
||||
accessManager: accessManager,
|
||||
requestLogger: requestLogger,
|
||||
@@ -251,7 +258,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
||||
s.applyAccessConfig(nil, cfg)
|
||||
if authManager != nil {
|
||||
authManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
|
||||
authManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second, cfg.MaxRetryCredentials)
|
||||
}
|
||||
managementasset.SetCurrentConfig(cfg)
|
||||
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||
@@ -260,11 +267,11 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
if optionState.localPassword != "" {
|
||||
s.mgmt.SetLocalPassword(optionState.localPassword)
|
||||
}
|
||||
logDir := filepath.Join(s.currentPath, "logs")
|
||||
if base := util.WritablePath(); base != "" {
|
||||
logDir = filepath.Join(base, "logs")
|
||||
}
|
||||
logDir := logging.ResolveLogDirectory(cfg)
|
||||
s.mgmt.SetLogDirectory(logDir)
|
||||
if optionState.postAuthHook != nil {
|
||||
s.mgmt.SetPostAuthHook(optionState.postAuthHook)
|
||||
}
|
||||
s.localPassword = optionState.localPassword
|
||||
|
||||
// Setup routes
|
||||
@@ -287,20 +294,26 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
optionState.routerConfigurator(engine, s.handlers, cfg)
|
||||
}
|
||||
|
||||
// Register management routes when configuration or environment secrets are available.
|
||||
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret
|
||||
// Register management routes when configuration or environment secrets are available,
|
||||
// or when a local management password is provided (e.g. TUI mode).
|
||||
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret || s.localPassword != ""
|
||||
s.managementRoutesEnabled.Store(hasManagementSecret)
|
||||
if hasManagementSecret {
|
||||
s.registerManagementRoutes()
|
||||
}
|
||||
|
||||
// === CLIProxyAPIPlus 扩展: 注册 Kiro OAuth Web 路由 ===
|
||||
kiroOAuthHandler := kiro.NewOAuthWebHandler(cfg)
|
||||
kiroOAuthHandler.RegisterRoutes(engine)
|
||||
log.Info("Kiro OAuth Web routes registered at /v0/oauth/kiro/*")
|
||||
|
||||
if optionState.keepAliveEnabled {
|
||||
s.enableKeepAlive(optionState.keepAliveTimeout, optionState.keepAliveOnTimeout)
|
||||
}
|
||||
|
||||
// Create HTTP server
|
||||
s.server = &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", cfg.Port),
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||||
Handler: engine,
|
||||
}
|
||||
|
||||
@@ -326,7 +339,9 @@ func (s *Server) setupRoutes() {
|
||||
v1.POST("/completions", openaiHandlers.Completions)
|
||||
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
|
||||
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
|
||||
v1.GET("/responses", openaiResponsesHandlers.ResponsesWebsocket)
|
||||
v1.POST("/responses", openaiResponsesHandlers.Responses)
|
||||
v1.POST("/responses/compact", openaiResponsesHandlers.Compact)
|
||||
}
|
||||
|
||||
// Gemini compatible API routes
|
||||
@@ -334,8 +349,8 @@ func (s *Server) setupRoutes() {
|
||||
v1beta.Use(AuthMiddleware(s.accessManager))
|
||||
{
|
||||
v1beta.GET("/models", geminiHandlers.GeminiModels)
|
||||
v1beta.POST("/models/:action", geminiHandlers.GeminiHandler)
|
||||
v1beta.GET("/models/:action", geminiHandlers.GeminiGetHandler)
|
||||
v1beta.POST("/models/*action", geminiHandlers.GeminiHandler)
|
||||
v1beta.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
||||
}
|
||||
|
||||
// Root endpoint
|
||||
@@ -349,6 +364,12 @@ func (s *Server) setupRoutes() {
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
// Event logging endpoint - handles Claude Code telemetry requests
|
||||
// Returns 200 OK to prevent 404 errors in logs
|
||||
s.engine.POST("/api/event_logging/batch", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)
|
||||
|
||||
// OAuth callback endpoints (reuse main server port)
|
||||
@@ -358,10 +379,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
// Persist to a temporary file keyed by state
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-anthropic-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "anthropic", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -371,9 +393,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-codex-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "codex", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -383,9 +407,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-gemini-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gemini", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -395,9 +421,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-iflow-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "iflow", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -407,9 +435,25 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-antigravity-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "antigravity", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
})
|
||||
|
||||
s.engine.GET("/kiro/callback", func(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "kiro", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -469,9 +513,12 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware())
|
||||
{
|
||||
mgmt.GET("/usage", s.mgmt.GetUsageStatistics)
|
||||
mgmt.GET("/usage/export", s.mgmt.ExportUsageStatistics)
|
||||
mgmt.POST("/usage/import", s.mgmt.ImportUsageStatistics)
|
||||
mgmt.GET("/config", s.mgmt.GetConfig)
|
||||
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
|
||||
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
|
||||
mgmt.GET("/latest-version", s.mgmt.GetLatestVersion)
|
||||
|
||||
mgmt.GET("/debug", s.mgmt.GetDebug)
|
||||
mgmt.PUT("/debug", s.mgmt.PutDebug)
|
||||
@@ -481,6 +528,14 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.PUT("/logging-to-file", s.mgmt.PutLoggingToFile)
|
||||
mgmt.PATCH("/logging-to-file", s.mgmt.PutLoggingToFile)
|
||||
|
||||
mgmt.GET("/logs-max-total-size-mb", s.mgmt.GetLogsMaxTotalSizeMB)
|
||||
mgmt.PUT("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
|
||||
mgmt.PATCH("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
|
||||
|
||||
mgmt.GET("/error-logs-max-files", s.mgmt.GetErrorLogsMaxFiles)
|
||||
mgmt.PUT("/error-logs-max-files", s.mgmt.PutErrorLogsMaxFiles)
|
||||
mgmt.PATCH("/error-logs-max-files", s.mgmt.PutErrorLogsMaxFiles)
|
||||
|
||||
mgmt.GET("/usage-statistics-enabled", s.mgmt.GetUsageStatisticsEnabled)
|
||||
mgmt.PUT("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
|
||||
mgmt.PATCH("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
|
||||
@@ -490,6 +545,8 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.PATCH("/proxy-url", s.mgmt.PutProxyURL)
|
||||
mgmt.DELETE("/proxy-url", s.mgmt.DeleteProxyURL)
|
||||
|
||||
mgmt.POST("/api-call", s.mgmt.APICall)
|
||||
|
||||
mgmt.GET("/quota-exceeded/switch-project", s.mgmt.GetSwitchProject)
|
||||
mgmt.PUT("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
|
||||
mgmt.PATCH("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
|
||||
@@ -512,6 +569,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.DELETE("/logs", s.mgmt.DeleteLogs)
|
||||
mgmt.GET("/request-error-logs", s.mgmt.GetRequestErrorLogs)
|
||||
mgmt.GET("/request-error-logs/:name", s.mgmt.DownloadRequestErrorLog)
|
||||
mgmt.GET("/request-log-by-id/:id", s.mgmt.GetRequestLogByID)
|
||||
mgmt.GET("/request-log", s.mgmt.GetRequestLog)
|
||||
mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
|
||||
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
|
||||
@@ -519,6 +577,30 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth)
|
||||
mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth)
|
||||
|
||||
mgmt.GET("/ampcode", s.mgmt.GetAmpCode)
|
||||
mgmt.GET("/ampcode/upstream-url", s.mgmt.GetAmpUpstreamURL)
|
||||
mgmt.PUT("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL)
|
||||
mgmt.PATCH("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL)
|
||||
mgmt.DELETE("/ampcode/upstream-url", s.mgmt.DeleteAmpUpstreamURL)
|
||||
mgmt.GET("/ampcode/upstream-api-key", s.mgmt.GetAmpUpstreamAPIKey)
|
||||
mgmt.PUT("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey)
|
||||
mgmt.PATCH("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey)
|
||||
mgmt.DELETE("/ampcode/upstream-api-key", s.mgmt.DeleteAmpUpstreamAPIKey)
|
||||
mgmt.GET("/ampcode/restrict-management-to-localhost", s.mgmt.GetAmpRestrictManagementToLocalhost)
|
||||
mgmt.PUT("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost)
|
||||
mgmt.PATCH("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost)
|
||||
mgmt.GET("/ampcode/model-mappings", s.mgmt.GetAmpModelMappings)
|
||||
mgmt.PUT("/ampcode/model-mappings", s.mgmt.PutAmpModelMappings)
|
||||
mgmt.PATCH("/ampcode/model-mappings", s.mgmt.PatchAmpModelMappings)
|
||||
mgmt.DELETE("/ampcode/model-mappings", s.mgmt.DeleteAmpModelMappings)
|
||||
mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings)
|
||||
mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
||||
mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
||||
mgmt.GET("/ampcode/upstream-api-keys", s.mgmt.GetAmpUpstreamAPIKeys)
|
||||
mgmt.PUT("/ampcode/upstream-api-keys", s.mgmt.PutAmpUpstreamAPIKeys)
|
||||
mgmt.PATCH("/ampcode/upstream-api-keys", s.mgmt.PatchAmpUpstreamAPIKeys)
|
||||
mgmt.DELETE("/ampcode/upstream-api-keys", s.mgmt.DeleteAmpUpstreamAPIKeys)
|
||||
|
||||
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
|
||||
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
|
||||
mgmt.PATCH("/request-retry", s.mgmt.PutRequestRetry)
|
||||
@@ -526,6 +608,14 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.PUT("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
|
||||
mgmt.PATCH("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
|
||||
|
||||
mgmt.GET("/force-model-prefix", s.mgmt.GetForceModelPrefix)
|
||||
mgmt.PUT("/force-model-prefix", s.mgmt.PutForceModelPrefix)
|
||||
mgmt.PATCH("/force-model-prefix", s.mgmt.PutForceModelPrefix)
|
||||
|
||||
mgmt.GET("/routing/strategy", s.mgmt.GetRoutingStrategy)
|
||||
mgmt.PUT("/routing/strategy", s.mgmt.PutRoutingStrategy)
|
||||
mgmt.PATCH("/routing/strategy", s.mgmt.PutRoutingStrategy)
|
||||
|
||||
mgmt.GET("/claude-api-key", s.mgmt.GetClaudeKeys)
|
||||
mgmt.PUT("/claude-api-key", s.mgmt.PutClaudeKeys)
|
||||
mgmt.PATCH("/claude-api-key", s.mgmt.PatchClaudeKey)
|
||||
@@ -541,15 +631,29 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.PATCH("/openai-compatibility", s.mgmt.PatchOpenAICompat)
|
||||
mgmt.DELETE("/openai-compatibility", s.mgmt.DeleteOpenAICompat)
|
||||
|
||||
mgmt.GET("/vertex-api-key", s.mgmt.GetVertexCompatKeys)
|
||||
mgmt.PUT("/vertex-api-key", s.mgmt.PutVertexCompatKeys)
|
||||
mgmt.PATCH("/vertex-api-key", s.mgmt.PatchVertexCompatKey)
|
||||
mgmt.DELETE("/vertex-api-key", s.mgmt.DeleteVertexCompatKey)
|
||||
|
||||
mgmt.GET("/oauth-excluded-models", s.mgmt.GetOAuthExcludedModels)
|
||||
mgmt.PUT("/oauth-excluded-models", s.mgmt.PutOAuthExcludedModels)
|
||||
mgmt.PATCH("/oauth-excluded-models", s.mgmt.PatchOAuthExcludedModels)
|
||||
mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels)
|
||||
|
||||
mgmt.GET("/oauth-model-alias", s.mgmt.GetOAuthModelAlias)
|
||||
mgmt.PUT("/oauth-model-alias", s.mgmt.PutOAuthModelAlias)
|
||||
mgmt.PATCH("/oauth-model-alias", s.mgmt.PatchOAuthModelAlias)
|
||||
mgmt.DELETE("/oauth-model-alias", s.mgmt.DeleteOAuthModelAlias)
|
||||
|
||||
mgmt.GET("/auth-files", s.mgmt.ListAuthFiles)
|
||||
mgmt.GET("/auth-files/models", s.mgmt.GetAuthFileModels)
|
||||
mgmt.GET("/model-definitions/:channel", s.mgmt.GetStaticModelDefinitions)
|
||||
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
|
||||
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
||||
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
|
||||
mgmt.PATCH("/auth-files/status", s.mgmt.PatchAuthFileStatus)
|
||||
mgmt.PATCH("/auth-files/fields", s.mgmt.PatchAuthFileFields)
|
||||
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
|
||||
|
||||
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
||||
@@ -557,8 +661,13 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
||||
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
||||
mgmt.GET("/kilo-auth-url", s.mgmt.RequestKiloToken)
|
||||
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
|
||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
|
||||
mgmt.GET("/github-auth-url", s.mgmt.RequestGitHubToken)
|
||||
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
||||
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
||||
}
|
||||
}
|
||||
@@ -587,14 +696,17 @@ func (s *Server) serveManagementControlPanel(c *gin.Context) {
|
||||
|
||||
if _, err := os.Stat(filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL)
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
// Synchronously ensure management.html is available with a detached context.
|
||||
// Control panel bootstrap should not be canceled by client disconnects.
|
||||
if !managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository) {
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
log.WithError(err).Error("failed to stat management control panel asset")
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
log.WithError(err).Error("failed to stat management control panel asset")
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
c.File(filePath)
|
||||
@@ -809,50 +921,35 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
} else if toggler, ok := s.requestLogger.(interface{ SetEnabled(bool) }); ok {
|
||||
toggler.SetEnabled(cfg.RequestLog)
|
||||
}
|
||||
if oldCfg != nil {
|
||||
log.Debugf("request logging updated from %t to %t", previousRequestLog, cfg.RequestLog)
|
||||
} else {
|
||||
log.Debugf("request logging toggled to %t", cfg.RequestLog)
|
||||
}
|
||||
}
|
||||
|
||||
if oldCfg != nil && oldCfg.LoggingToFile != cfg.LoggingToFile {
|
||||
if err := logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
|
||||
if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
|
||||
if err := logging.ConfigureLogOutput(cfg); err != nil {
|
||||
log.Errorf("failed to reconfigure log output: %v", err)
|
||||
} else {
|
||||
log.Debugf("logging_to_file updated from %t to %t", oldCfg.LoggingToFile, cfg.LoggingToFile)
|
||||
}
|
||||
}
|
||||
|
||||
if oldCfg == nil || oldCfg.UsageStatisticsEnabled != cfg.UsageStatisticsEnabled {
|
||||
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
|
||||
if oldCfg != nil {
|
||||
log.Debugf("usage_statistics_enabled updated from %t to %t", oldCfg.UsageStatisticsEnabled, cfg.UsageStatisticsEnabled)
|
||||
} else {
|
||||
log.Debugf("usage_statistics_enabled toggled to %t", cfg.UsageStatisticsEnabled)
|
||||
}
|
||||
|
||||
if s.requestLogger != nil && (oldCfg == nil || oldCfg.ErrorLogsMaxFiles != cfg.ErrorLogsMaxFiles) {
|
||||
if setter, ok := s.requestLogger.(interface{ SetErrorLogsMaxFiles(int) }); ok {
|
||||
setter.SetErrorLogsMaxFiles(cfg.ErrorLogsMaxFiles)
|
||||
}
|
||||
}
|
||||
|
||||
if oldCfg == nil || oldCfg.DisableCooling != cfg.DisableCooling {
|
||||
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||
if oldCfg != nil {
|
||||
log.Debugf("disable_cooling updated from %t to %t", oldCfg.DisableCooling, cfg.DisableCooling)
|
||||
} else {
|
||||
log.Debugf("disable_cooling toggled to %t", cfg.DisableCooling)
|
||||
}
|
||||
}
|
||||
|
||||
if s.handlers != nil && s.handlers.AuthManager != nil {
|
||||
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
|
||||
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second, cfg.MaxRetryCredentials)
|
||||
}
|
||||
|
||||
// Update log level dynamically when debug flag changes
|
||||
if oldCfg == nil || oldCfg.Debug != cfg.Debug {
|
||||
util.SetLogLevel(cfg)
|
||||
if oldCfg != nil {
|
||||
log.Debugf("debug mode updated from %t to %t", oldCfg.Debug, cfg.Debug)
|
||||
} else {
|
||||
log.Debugf("debug mode toggled to %t", cfg.Debug)
|
||||
}
|
||||
}
|
||||
|
||||
prevSecretEmpty := true
|
||||
@@ -897,35 +994,32 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
// Save YAML snapshot for next comparison
|
||||
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
||||
|
||||
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
|
||||
for _, p := range cfg.OpenAICompatibility {
|
||||
providerNames = append(providerNames, p.Name)
|
||||
}
|
||||
s.handlers.OpenAICompatProviders = providerNames
|
||||
|
||||
s.handlers.UpdateClients(&cfg.SDKConfig)
|
||||
|
||||
if !cfg.RemoteManagement.DisableControlPanel {
|
||||
staticDir := managementasset.StaticDir(s.configFilePath)
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL)
|
||||
}
|
||||
if s.mgmt != nil {
|
||||
s.mgmt.SetConfig(cfg)
|
||||
s.mgmt.SetAuthManager(s.handlers.AuthManager)
|
||||
}
|
||||
|
||||
// Notify Amp module of config changes (for model mapping hot-reload)
|
||||
if s.ampModule != nil {
|
||||
log.Debugf("triggering amp module config update")
|
||||
if err := s.ampModule.OnConfigUpdated(cfg); err != nil {
|
||||
log.Errorf("failed to update Amp module config: %v", err)
|
||||
// Notify Amp module only when Amp config has changed.
|
||||
ampConfigChanged := oldCfg == nil || !reflect.DeepEqual(oldCfg.AmpCode, cfg.AmpCode)
|
||||
if ampConfigChanged {
|
||||
if s.ampModule != nil {
|
||||
log.Debugf("triggering amp module config update")
|
||||
if err := s.ampModule.OnConfigUpdated(cfg); err != nil {
|
||||
log.Errorf("failed to update Amp module config: %v", err)
|
||||
}
|
||||
} else {
|
||||
log.Warnf("amp module is nil, skipping config update")
|
||||
}
|
||||
} else {
|
||||
log.Warnf("amp module is nil, skipping config update")
|
||||
}
|
||||
|
||||
// Count client sources from configuration and auth directory
|
||||
authFiles := util.CountAuthFiles(cfg.AuthDir)
|
||||
// Count client sources from configuration and auth store.
|
||||
tokenStore := sdkAuth.GetTokenStore()
|
||||
if dirSetter, ok := tokenStore.(interface{ SetBaseDir(string) }); ok {
|
||||
dirSetter.SetBaseDir(cfg.AuthDir)
|
||||
}
|
||||
authEntries := util.CountAuthFiles(context.Background(), tokenStore)
|
||||
geminiAPIKeyCount := len(cfg.GeminiKey)
|
||||
claudeAPIKeyCount := len(cfg.ClaudeKey)
|
||||
codexAPIKeyCount := len(cfg.CodexKey)
|
||||
@@ -936,10 +1030,10 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
openAICompatCount += len(entry.APIKeyEntries)
|
||||
}
|
||||
|
||||
total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
|
||||
fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
|
||||
total := authEntries + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
|
||||
fmt.Printf("server clients and configuration updated: %d clients (%d auth entries + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
|
||||
total,
|
||||
authFiles,
|
||||
authEntries,
|
||||
geminiAPIKeyCount,
|
||||
claudeAPIKeyCount,
|
||||
codexAPIKeyCount,
|
||||
@@ -980,14 +1074,10 @@ func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case errors.Is(err, sdkaccess.ErrNoCredentials):
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Missing API key"})
|
||||
case errors.Is(err, sdkaccess.ErrInvalidCredential):
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"})
|
||||
default:
|
||||
statusCode := err.HTTPStatusCode()
|
||||
if statusCode >= http.StatusInternalServerError {
|
||||
log.Errorf("authentication middleware error: %v", err)
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Authentication service error"})
|
||||
}
|
||||
c.AbortWithStatusJSON(statusCode, gin.H{"error": err.Message})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,9 +7,11 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
gin "github.com/gin-gonic/gin"
|
||||
proxyconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
@@ -109,3 +111,100 @@ func TestAmpProviderModelRoutes(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) {
|
||||
t.Setenv("WRITABLE_PATH", "")
|
||||
t.Setenv("writable_path", "")
|
||||
|
||||
originalWD, errGetwd := os.Getwd()
|
||||
if errGetwd != nil {
|
||||
t.Fatalf("failed to get current working directory: %v", errGetwd)
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
if errChdir := os.Chdir(tmpDir); errChdir != nil {
|
||||
t.Fatalf("failed to switch working directory: %v", errChdir)
|
||||
}
|
||||
defer func() {
|
||||
if errChdirBack := os.Chdir(originalWD); errChdirBack != nil {
|
||||
t.Fatalf("failed to restore working directory: %v", errChdirBack)
|
||||
}
|
||||
}()
|
||||
|
||||
// Force ResolveLogDirectory to fallback to auth-dir/logs by making ./logs not a writable directory.
|
||||
if errWriteFile := os.WriteFile(filepath.Join(tmpDir, "logs"), []byte("not-a-directory"), 0o644); errWriteFile != nil {
|
||||
t.Fatalf("failed to create blocking logs file: %v", errWriteFile)
|
||||
}
|
||||
|
||||
configDir := filepath.Join(tmpDir, "config")
|
||||
if errMkdirConfig := os.MkdirAll(configDir, 0o755); errMkdirConfig != nil {
|
||||
t.Fatalf("failed to create config dir: %v", errMkdirConfig)
|
||||
}
|
||||
configPath := filepath.Join(configDir, "config.yaml")
|
||||
|
||||
authDir := filepath.Join(tmpDir, "auth")
|
||||
if errMkdirAuth := os.MkdirAll(authDir, 0o700); errMkdirAuth != nil {
|
||||
t.Fatalf("failed to create auth dir: %v", errMkdirAuth)
|
||||
}
|
||||
|
||||
cfg := &proxyconfig.Config{
|
||||
SDKConfig: proxyconfig.SDKConfig{
|
||||
RequestLog: false,
|
||||
},
|
||||
AuthDir: authDir,
|
||||
ErrorLogsMaxFiles: 10,
|
||||
}
|
||||
|
||||
logger := defaultRequestLoggerFactory(cfg, configPath)
|
||||
fileLogger, ok := logger.(*internallogging.FileRequestLogger)
|
||||
if !ok {
|
||||
t.Fatalf("expected *FileRequestLogger, got %T", logger)
|
||||
}
|
||||
|
||||
errLog := fileLogger.LogRequestWithOptions(
|
||||
"/v1/chat/completions",
|
||||
http.MethodPost,
|
||||
map[string][]string{"Content-Type": []string{"application/json"}},
|
||||
[]byte(`{"input":"hello"}`),
|
||||
http.StatusBadGateway,
|
||||
map[string][]string{"Content-Type": []string{"application/json"}},
|
||||
[]byte(`{"error":"upstream failure"}`),
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
"issue-1711",
|
||||
time.Now(),
|
||||
time.Now(),
|
||||
)
|
||||
if errLog != nil {
|
||||
t.Fatalf("failed to write forced error request log: %v", errLog)
|
||||
}
|
||||
|
||||
authLogsDir := filepath.Join(authDir, "logs")
|
||||
authEntries, errReadAuthDir := os.ReadDir(authLogsDir)
|
||||
if errReadAuthDir != nil {
|
||||
t.Fatalf("failed to read auth logs dir %s: %v", authLogsDir, errReadAuthDir)
|
||||
}
|
||||
foundErrorLogInAuthDir := false
|
||||
for _, entry := range authEntries {
|
||||
if strings.HasPrefix(entry.Name(), "error-") && strings.HasSuffix(entry.Name(), ".log") {
|
||||
foundErrorLogInAuthDir = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundErrorLogInAuthDir {
|
||||
t.Fatalf("expected forced error log in auth fallback dir %s, got entries: %+v", authLogsDir, authEntries)
|
||||
}
|
||||
|
||||
configLogsDir := filepath.Join(configDir, "logs")
|
||||
configEntries, errReadConfigDir := os.ReadDir(configLogsDir)
|
||||
if errReadConfigDir != nil && !os.IsNotExist(errReadConfigDir) {
|
||||
t.Fatalf("failed to inspect config logs dir %s: %v", configLogsDir, errReadConfigDir)
|
||||
}
|
||||
for _, entry := range configEntries {
|
||||
if strings.HasPrefix(entry.Name(), "error-") && strings.HasSuffix(entry.Name(), ".log") {
|
||||
t.Fatalf("unexpected forced error log in config dir %s", configLogsDir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
344
internal/auth/antigravity/auth.go
Normal file
344
internal/auth/antigravity/auth.go
Normal file
@@ -0,0 +1,344 @@
|
||||
// Package antigravity provides OAuth2 authentication functionality for the Antigravity provider.
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// TokenResponse represents OAuth token response from Google
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
// userInfo represents Google user profile
|
||||
type userInfo struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
// AntigravityAuth handles Antigravity OAuth authentication
|
||||
type AntigravityAuth struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewAntigravityAuth creates a new Antigravity auth service.
|
||||
func NewAntigravityAuth(cfg *config.Config, httpClient *http.Client) *AntigravityAuth {
|
||||
if httpClient != nil {
|
||||
return &AntigravityAuth{httpClient: httpClient}
|
||||
}
|
||||
if cfg == nil {
|
||||
cfg = &config.Config{}
|
||||
}
|
||||
return &AntigravityAuth{
|
||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}),
|
||||
}
|
||||
}
|
||||
|
||||
// BuildAuthURL generates the OAuth authorization URL.
|
||||
func (o *AntigravityAuth) BuildAuthURL(state, redirectURI string) string {
|
||||
if strings.TrimSpace(redirectURI) == "" {
|
||||
redirectURI = fmt.Sprintf("http://localhost:%d/oauth-callback", CallbackPort)
|
||||
}
|
||||
params := url.Values{}
|
||||
params.Set("access_type", "offline")
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("prompt", "consent")
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("response_type", "code")
|
||||
params.Set("scope", strings.Join(Scopes, " "))
|
||||
params.Set("state", state)
|
||||
return AuthEndpoint + "?" + params.Encode()
|
||||
}
|
||||
|
||||
// ExchangeCodeForTokens exchanges authorization code for access and refresh tokens
|
||||
func (o *AntigravityAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectURI string) (*TokenResponse, error) {
|
||||
data := url.Values{}
|
||||
data.Set("code", code)
|
||||
data.Set("client_id", ClientID)
|
||||
data.Set("client_secret", ClientSecret)
|
||||
data.Set("redirect_uri", redirectURI)
|
||||
data.Set("grant_type", "authorization_code")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenEndpoint, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("antigravity token exchange: create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, errDo := o.httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
return nil, fmt.Errorf("antigravity token exchange: execute request: %w", errDo)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity token exchange: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
bodyBytes, errRead := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
|
||||
if errRead != nil {
|
||||
return nil, fmt.Errorf("antigravity token exchange: read response: %w", errRead)
|
||||
}
|
||||
body := strings.TrimSpace(string(bodyBytes))
|
||||
if body == "" {
|
||||
return nil, fmt.Errorf("antigravity token exchange: request failed: status %d", resp.StatusCode)
|
||||
}
|
||||
return nil, fmt.Errorf("antigravity token exchange: request failed: status %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var token TokenResponse
|
||||
if errDecode := json.NewDecoder(resp.Body).Decode(&token); errDecode != nil {
|
||||
return nil, fmt.Errorf("antigravity token exchange: decode response: %w", errDecode)
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// FetchUserInfo retrieves user email from Google
|
||||
func (o *AntigravityAuth) FetchUserInfo(ctx context.Context, accessToken string) (string, error) {
|
||||
accessToken = strings.TrimSpace(accessToken)
|
||||
if accessToken == "" {
|
||||
return "", fmt.Errorf("antigravity userinfo: missing access token")
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, UserInfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("antigravity userinfo: create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
resp, errDo := o.httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
return "", fmt.Errorf("antigravity userinfo: execute request: %w", errDo)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity userinfo: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
bodyBytes, errRead := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
|
||||
if errRead != nil {
|
||||
return "", fmt.Errorf("antigravity userinfo: read response: %w", errRead)
|
||||
}
|
||||
body := strings.TrimSpace(string(bodyBytes))
|
||||
if body == "" {
|
||||
return "", fmt.Errorf("antigravity userinfo: request failed: status %d", resp.StatusCode)
|
||||
}
|
||||
return "", fmt.Errorf("antigravity userinfo: request failed: status %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
var info userInfo
|
||||
if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil {
|
||||
return "", fmt.Errorf("antigravity userinfo: decode response: %w", errDecode)
|
||||
}
|
||||
email := strings.TrimSpace(info.Email)
|
||||
if email == "" {
|
||||
return "", fmt.Errorf("antigravity userinfo: response missing email")
|
||||
}
|
||||
return email, nil
|
||||
}
|
||||
|
||||
// FetchProjectID retrieves the project ID for the authenticated user via loadCodeAssist
|
||||
func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string) (string, error) {
|
||||
loadReqBody := map[string]any{
|
||||
"metadata": map[string]string{
|
||||
"ideType": "ANTIGRAVITY",
|
||||
"platform": "PLATFORM_UNSPECIFIED",
|
||||
"pluginType": "GEMINI",
|
||||
},
|
||||
}
|
||||
|
||||
rawBody, errMarshal := json.Marshal(loadReqBody)
|
||||
if errMarshal != nil {
|
||||
return "", fmt.Errorf("marshal request body: %w", errMarshal)
|
||||
}
|
||||
|
||||
endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", APIEndpoint, APIVersion)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", APIUserAgent)
|
||||
req.Header.Set("X-Goog-Api-Client", APIClient)
|
||||
req.Header.Set("Client-Metadata", ClientMetadata)
|
||||
|
||||
resp, errDo := o.httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
return "", fmt.Errorf("execute request: %w", errDo)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(resp.Body)
|
||||
if errRead != nil {
|
||||
return "", fmt.Errorf("read response: %w", errRead)
|
||||
}
|
||||
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
|
||||
}
|
||||
|
||||
var loadResp map[string]any
|
||||
if errDecode := json.Unmarshal(bodyBytes, &loadResp); errDecode != nil {
|
||||
return "", fmt.Errorf("decode response: %w", errDecode)
|
||||
}
|
||||
|
||||
// Extract projectID from response
|
||||
projectID := ""
|
||||
if id, ok := loadResp["cloudaicompanionProject"].(string); ok {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
if projectID == "" {
|
||||
if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok {
|
||||
if id, okID := projectMap["id"].(string); okID {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if projectID == "" {
|
||||
tierID := "legacy-tier"
|
||||
if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
|
||||
for _, rawTier := range tiers {
|
||||
tier, okTier := rawTier.(map[string]any)
|
||||
if !okTier {
|
||||
continue
|
||||
}
|
||||
if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault {
|
||||
if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
|
||||
tierID = strings.TrimSpace(id)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
projectID, err = o.OnboardUser(ctx, accessToken, tierID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return projectID, nil
|
||||
}
|
||||
|
||||
return projectID, nil
|
||||
}
|
||||
|
||||
// OnboardUser attempts to fetch the project ID via onboardUser by polling for completion
|
||||
func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) {
|
||||
log.Infof("Antigravity: onboarding user with tier: %s", tierID)
|
||||
requestBody := map[string]any{
|
||||
"tierId": tierID,
|
||||
"metadata": map[string]string{
|
||||
"ideType": "ANTIGRAVITY",
|
||||
"platform": "PLATFORM_UNSPECIFIED",
|
||||
"pluginType": "GEMINI",
|
||||
},
|
||||
}
|
||||
|
||||
rawBody, errMarshal := json.Marshal(requestBody)
|
||||
if errMarshal != nil {
|
||||
return "", fmt.Errorf("marshal request body: %w", errMarshal)
|
||||
}
|
||||
|
||||
maxAttempts := 5
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
log.Debugf("Polling attempt %d/%d", attempt, maxAttempts)
|
||||
|
||||
reqCtx := ctx
|
||||
var cancel context.CancelFunc
|
||||
if reqCtx == nil {
|
||||
reqCtx = context.Background()
|
||||
}
|
||||
reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second)
|
||||
|
||||
endpointURL := fmt.Sprintf("%s/%s:onboardUser", APIEndpoint, APIVersion)
|
||||
req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
|
||||
if errRequest != nil {
|
||||
cancel()
|
||||
return "", fmt.Errorf("create request: %w", errRequest)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", APIUserAgent)
|
||||
req.Header.Set("X-Goog-Api-Client", APIClient)
|
||||
req.Header.Set("Client-Metadata", ClientMetadata)
|
||||
|
||||
resp, errDo := o.httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
cancel()
|
||||
return "", fmt.Errorf("execute request: %w", errDo)
|
||||
}
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(resp.Body)
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("close body error: %v", errClose)
|
||||
}
|
||||
cancel()
|
||||
|
||||
if errRead != nil {
|
||||
return "", fmt.Errorf("read response: %w", errRead)
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
var data map[string]any
|
||||
if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil {
|
||||
return "", fmt.Errorf("decode response: %w", errDecode)
|
||||
}
|
||||
|
||||
if done, okDone := data["done"].(bool); okDone && done {
|
||||
projectID := ""
|
||||
if responseData, okResp := data["response"].(map[string]any); okResp {
|
||||
switch projectValue := responseData["cloudaicompanionProject"].(type) {
|
||||
case map[string]any:
|
||||
if id, okID := projectValue["id"].(string); okID {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
case string:
|
||||
projectID = strings.TrimSpace(projectValue)
|
||||
}
|
||||
}
|
||||
|
||||
if projectID != "" {
|
||||
log.Infof("Successfully fetched project_id: %s", projectID)
|
||||
return projectID, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no project_id in response")
|
||||
}
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
responsePreview := strings.TrimSpace(string(bodyBytes))
|
||||
if len(responsePreview) > 500 {
|
||||
responsePreview = responsePreview[:500]
|
||||
}
|
||||
|
||||
responseErr := responsePreview
|
||||
if len(responseErr) > 200 {
|
||||
responseErr = responseErr[:200]
|
||||
}
|
||||
return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr)
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
34
internal/auth/antigravity/constants.go
Normal file
34
internal/auth/antigravity/constants.go
Normal file
@@ -0,0 +1,34 @@
|
||||
// Package antigravity provides OAuth2 authentication functionality for the Antigravity provider.
|
||||
package antigravity
|
||||
|
||||
// OAuth client credentials and configuration
|
||||
const (
|
||||
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
ClientSecret = "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"
|
||||
CallbackPort = 51121
|
||||
)
|
||||
|
||||
// Scopes defines the OAuth scopes required for Antigravity authentication
|
||||
var Scopes = []string{
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
"https://www.googleapis.com/auth/cclog",
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs",
|
||||
}
|
||||
|
||||
// OAuth2 endpoints for Google authentication
|
||||
const (
|
||||
TokenEndpoint = "https://oauth2.googleapis.com/token"
|
||||
AuthEndpoint = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
UserInfoEndpoint = "https://www.googleapis.com/oauth2/v1/userinfo?alt=json"
|
||||
)
|
||||
|
||||
// Antigravity API configuration
|
||||
const (
|
||||
APIEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||
APIVersion = "v1internal"
|
||||
APIUserAgent = "google-api-nodejs-client/9.15.1"
|
||||
APIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1"
|
||||
ClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}`
|
||||
)
|
||||
16
internal/auth/antigravity/filename.go
Normal file
16
internal/auth/antigravity/filename.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CredentialFileName returns the filename used to persist Antigravity credentials.
|
||||
// It uses the email as a suffix to disambiguate accounts.
|
||||
func CredentialFileName(email string) string {
|
||||
email = strings.TrimSpace(email)
|
||||
if email == "" {
|
||||
return "antigravity.json"
|
||||
}
|
||||
return fmt.Sprintf("antigravity-%s.json", email)
|
||||
}
|
||||
@@ -14,15 +14,15 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// OAuth configuration constants for Claude/Anthropic
|
||||
const (
|
||||
anthropicAuthURL = "https://claude.ai/oauth/authorize"
|
||||
anthropicTokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||
anthropicClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
redirectURI = "http://localhost:54545/callback"
|
||||
AuthURL = "https://claude.ai/oauth/authorize"
|
||||
TokenURL = "https://api.anthropic.com/v1/oauth/token"
|
||||
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
RedirectURI = "http://localhost:54545/callback"
|
||||
)
|
||||
|
||||
// tokenResponse represents the response structure from Anthropic's OAuth token endpoint.
|
||||
@@ -50,7 +50,8 @@ type ClaudeAuth struct {
|
||||
}
|
||||
|
||||
// NewClaudeAuth creates a new Anthropic authentication service.
|
||||
// It initializes the HTTP client with proxy settings from the configuration.
|
||||
// It initializes the HTTP client with a custom TLS transport that uses Firefox
|
||||
// fingerprint to bypass Cloudflare's TLS fingerprinting on Anthropic domains.
|
||||
//
|
||||
// Parameters:
|
||||
// - cfg: The application configuration containing proxy settings
|
||||
@@ -58,8 +59,10 @@ type ClaudeAuth struct {
|
||||
// Returns:
|
||||
// - *ClaudeAuth: A new Claude authentication service instance
|
||||
func NewClaudeAuth(cfg *config.Config) *ClaudeAuth {
|
||||
// Use custom HTTP client with Firefox TLS fingerprint to bypass
|
||||
// Cloudflare's bot detection on Anthropic domains
|
||||
return &ClaudeAuth{
|
||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}),
|
||||
httpClient: NewAnthropicHttpClient(&cfg.SDKConfig),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,16 +85,16 @@ func (o *ClaudeAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string
|
||||
|
||||
params := url.Values{
|
||||
"code": {"true"},
|
||||
"client_id": {anthropicClientID},
|
||||
"client_id": {ClientID},
|
||||
"response_type": {"code"},
|
||||
"redirect_uri": {redirectURI},
|
||||
"redirect_uri": {RedirectURI},
|
||||
"scope": {"org:create_api_key user:profile user:inference"},
|
||||
"code_challenge": {pkceCodes.CodeChallenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
"state": {state},
|
||||
}
|
||||
|
||||
authURL := fmt.Sprintf("%s?%s", anthropicAuthURL, params.Encode())
|
||||
authURL := fmt.Sprintf("%s?%s", AuthURL, params.Encode())
|
||||
return authURL, state, nil
|
||||
}
|
||||
|
||||
@@ -137,8 +140,8 @@ func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state stri
|
||||
"code": newCode,
|
||||
"state": state,
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": anthropicClientID,
|
||||
"redirect_uri": redirectURI,
|
||||
"client_id": ClientID,
|
||||
"redirect_uri": RedirectURI,
|
||||
"code_verifier": pkceCodes.CodeVerifier,
|
||||
}
|
||||
|
||||
@@ -154,7 +157,7 @@ func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state stri
|
||||
|
||||
// log.Debugf("Token exchange request: %s", string(jsonBody))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody)))
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(string(jsonBody)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
@@ -221,7 +224,7 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"client_id": anthropicClientID,
|
||||
"client_id": ClientID,
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refreshToken,
|
||||
}
|
||||
@@ -231,7 +234,7 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody)))
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(string(jsonBody)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
||||
}
|
||||
|
||||
@@ -242,6 +242,11 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
|
||||
platformURL = "https://console.anthropic.com/"
|
||||
}
|
||||
|
||||
// Validate platformURL to prevent XSS - only allow http/https URLs
|
||||
if !isValidURL(platformURL) {
|
||||
platformURL = "https://console.anthropic.com/"
|
||||
}
|
||||
|
||||
// Generate success page HTML with dynamic content
|
||||
successHTML := s.generateSuccessHTML(setupRequired, platformURL)
|
||||
|
||||
@@ -251,6 +256,12 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// isValidURL checks if the URL is a valid http/https URL to prevent XSS
|
||||
func isValidURL(urlStr string) bool {
|
||||
urlStr = strings.TrimSpace(urlStr)
|
||||
return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://")
|
||||
}
|
||||
|
||||
// generateSuccessHTML creates the HTML content for the success page.
|
||||
// It customizes the page based on whether additional setup is required
|
||||
// and includes a link to the platform.
|
||||
|
||||
@@ -36,11 +36,21 @@ type ClaudeTokenStorage struct {
|
||||
|
||||
// Expire is the timestamp when the current access token expires.
|
||||
Expire string `json:"expired"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *ClaudeTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Claude token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -65,8 +75,14 @@ func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
// Encode and write the token data as JSON
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
167
internal/auth/claude/utls_transport.go
Normal file
167
internal/auth/claude/utls_transport.go
Normal file
@@ -0,0 +1,167 @@
|
||||
// Package claude provides authentication functionality for Anthropic's Claude API.
|
||||
// This file implements a custom HTTP transport using utls to bypass TLS fingerprinting.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
tls "github.com/refraction-networking/utls"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// utlsRoundTripper implements http.RoundTripper using utls with Chrome fingerprint
|
||||
// to bypass Cloudflare's TLS fingerprinting on Anthropic domains.
|
||||
type utlsRoundTripper struct {
|
||||
// mu protects the connections map and pending map
|
||||
mu sync.Mutex
|
||||
// connections caches HTTP/2 client connections per host
|
||||
connections map[string]*http2.ClientConn
|
||||
// pending tracks hosts that are currently being connected to (prevents race condition)
|
||||
pending map[string]*sync.Cond
|
||||
// dialer is used to create network connections, supporting proxies
|
||||
dialer proxy.Dialer
|
||||
}
|
||||
|
||||
// newUtlsRoundTripper creates a new utls-based round tripper with optional proxy support
|
||||
func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper {
|
||||
var dialer proxy.Dialer = proxy.Direct
|
||||
if cfg != nil && cfg.ProxyURL != "" {
|
||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse proxy URL %q: %v", cfg.ProxyURL, err)
|
||||
} else {
|
||||
pDialer, err := proxy.FromURL(proxyURL, proxy.Direct)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create proxy dialer for %q: %v", cfg.ProxyURL, err)
|
||||
} else {
|
||||
dialer = pDialer
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &utlsRoundTripper{
|
||||
connections: make(map[string]*http2.ClientConn),
|
||||
pending: make(map[string]*sync.Cond),
|
||||
dialer: dialer,
|
||||
}
|
||||
}
|
||||
|
||||
// getOrCreateConnection gets an existing connection or creates a new one.
|
||||
// It uses a per-host locking mechanism to prevent multiple goroutines from
|
||||
// creating connections to the same host simultaneously.
|
||||
func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.ClientConn, error) {
|
||||
t.mu.Lock()
|
||||
|
||||
// Check if connection exists and is usable
|
||||
if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
|
||||
t.mu.Unlock()
|
||||
return h2Conn, nil
|
||||
}
|
||||
|
||||
// Check if another goroutine is already creating a connection
|
||||
if cond, ok := t.pending[host]; ok {
|
||||
// Wait for the other goroutine to finish
|
||||
cond.Wait()
|
||||
// Check if connection is now available
|
||||
if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
|
||||
t.mu.Unlock()
|
||||
return h2Conn, nil
|
||||
}
|
||||
// Connection still not available, we'll create one
|
||||
}
|
||||
|
||||
// Mark this host as pending
|
||||
cond := sync.NewCond(&t.mu)
|
||||
t.pending[host] = cond
|
||||
t.mu.Unlock()
|
||||
|
||||
// Create connection outside the lock
|
||||
h2Conn, err := t.createConnection(host, addr)
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
// Remove pending marker and wake up waiting goroutines
|
||||
delete(t.pending, host)
|
||||
cond.Broadcast()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store the new connection
|
||||
t.connections[host] = h2Conn
|
||||
return h2Conn, nil
|
||||
}
|
||||
|
||||
// createConnection creates a new HTTP/2 connection with Chrome TLS fingerprint.
|
||||
// Chrome's TLS fingerprint is closer to Node.js/OpenSSL (which real Claude Code uses)
|
||||
// than Firefox, reducing the mismatch between TLS layer and HTTP headers.
|
||||
func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) {
|
||||
conn, err := t.dialer.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{ServerName: host}
|
||||
tlsConn := tls.UClient(conn, tlsConfig, tls.HelloChrome_Auto)
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tr := &http2.Transport{}
|
||||
h2Conn, err := tr.NewClientConn(tlsConn)
|
||||
if err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return h2Conn, nil
|
||||
}
|
||||
|
||||
// RoundTrip implements http.RoundTripper
|
||||
func (t *utlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
host := req.URL.Host
|
||||
addr := host
|
||||
if !strings.Contains(addr, ":") {
|
||||
addr += ":443"
|
||||
}
|
||||
|
||||
// Get hostname without port for TLS ServerName
|
||||
hostname := req.URL.Hostname()
|
||||
|
||||
h2Conn, err := t.getOrCreateConnection(hostname, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := h2Conn.RoundTrip(req)
|
||||
if err != nil {
|
||||
// Connection failed, remove it from cache
|
||||
t.mu.Lock()
|
||||
if cached, ok := t.connections[hostname]; ok && cached == h2Conn {
|
||||
delete(t.connections, hostname)
|
||||
}
|
||||
t.mu.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// NewAnthropicHttpClient creates an HTTP client that bypasses TLS fingerprinting
|
||||
// for Anthropic domains by using utls with Chrome fingerprint.
|
||||
// It accepts optional SDK configuration for proxy settings.
|
||||
func NewAnthropicHttpClient(cfg *config.SDKConfig) *http.Client {
|
||||
return &http.Client{
|
||||
Transport: newUtlsRoundTripper(cfg),
|
||||
}
|
||||
}
|
||||
46
internal/auth/codex/filename.go
Normal file
46
internal/auth/codex/filename.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package codex
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// CredentialFileName returns the filename used to persist Codex OAuth credentials.
|
||||
// When planType is available (e.g. "plus", "team"), it is appended after the email
|
||||
// as a suffix to disambiguate subscriptions.
|
||||
func CredentialFileName(email, planType, hashAccountID string, includeProviderPrefix bool) string {
|
||||
email = strings.TrimSpace(email)
|
||||
plan := normalizePlanTypeForFilename(planType)
|
||||
|
||||
prefix := ""
|
||||
if includeProviderPrefix {
|
||||
prefix = "codex"
|
||||
}
|
||||
|
||||
if plan == "" {
|
||||
return fmt.Sprintf("%s-%s.json", prefix, email)
|
||||
} else if plan == "team" {
|
||||
return fmt.Sprintf("%s-%s-%s-%s.json", prefix, hashAccountID, email, plan)
|
||||
}
|
||||
return fmt.Sprintf("%s-%s-%s.json", prefix, email, plan)
|
||||
}
|
||||
|
||||
func normalizePlanTypeForFilename(planType string) string {
|
||||
planType = strings.TrimSpace(planType)
|
||||
if planType == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
parts := strings.FieldsFunc(planType, func(r rune) bool {
|
||||
return !unicode.IsLetter(r) && !unicode.IsDigit(r)
|
||||
})
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
for i, part := range parts {
|
||||
parts[i] = strings.ToLower(strings.TrimSpace(part))
|
||||
}
|
||||
return strings.Join(parts, "-")
|
||||
}
|
||||
@@ -239,6 +239,11 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
|
||||
platformURL = "https://platform.openai.com"
|
||||
}
|
||||
|
||||
// Validate platformURL to prevent XSS - only allow http/https URLs
|
||||
if !isValidURL(platformURL) {
|
||||
platformURL = "https://platform.openai.com"
|
||||
}
|
||||
|
||||
// Generate success page HTML with dynamic content
|
||||
successHTML := s.generateSuccessHTML(setupRequired, platformURL)
|
||||
|
||||
@@ -248,6 +253,12 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// isValidURL checks if the URL is a valid http/https URL to prevent XSS
|
||||
func isValidURL(urlStr string) bool {
|
||||
urlStr = strings.TrimSpace(urlStr)
|
||||
return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://")
|
||||
}
|
||||
|
||||
// generateSuccessHTML creates the HTML content for the success page.
|
||||
// It customizes the page based on whether additional setup is required
|
||||
// and includes a link to the platform.
|
||||
|
||||
@@ -19,11 +19,12 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// OAuth configuration constants for OpenAI Codex
|
||||
const (
|
||||
openaiAuthURL = "https://auth.openai.com/oauth/authorize"
|
||||
openaiTokenURL = "https://auth.openai.com/oauth/token"
|
||||
openaiClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
redirectURI = "http://localhost:1455/auth/callback"
|
||||
AuthURL = "https://auth.openai.com/oauth/authorize"
|
||||
TokenURL = "https://auth.openai.com/oauth/token"
|
||||
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
RedirectURI = "http://localhost:1455/auth/callback"
|
||||
)
|
||||
|
||||
// CodexAuth handles the OpenAI OAuth2 authentication flow.
|
||||
@@ -50,9 +51,9 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
|
||||
}
|
||||
|
||||
params := url.Values{
|
||||
"client_id": {openaiClientID},
|
||||
"client_id": {ClientID},
|
||||
"response_type": {"code"},
|
||||
"redirect_uri": {redirectURI},
|
||||
"redirect_uri": {RedirectURI},
|
||||
"scope": {"openid email profile offline_access"},
|
||||
"state": {state},
|
||||
"code_challenge": {pkceCodes.CodeChallenge},
|
||||
@@ -62,7 +63,7 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
|
||||
"codex_cli_simplified_flow": {"true"},
|
||||
}
|
||||
|
||||
authURL := fmt.Sprintf("%s?%s", openaiAuthURL, params.Encode())
|
||||
authURL := fmt.Sprintf("%s?%s", AuthURL, params.Encode())
|
||||
return authURL, nil
|
||||
}
|
||||
|
||||
@@ -70,20 +71,30 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
|
||||
// It performs an HTTP POST request to the OpenAI token endpoint with the provided
|
||||
// authorization code and PKCE verifier.
|
||||
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
||||
return o.ExchangeCodeForTokensWithRedirect(ctx, code, RedirectURI, pkceCodes)
|
||||
}
|
||||
|
||||
// ExchangeCodeForTokensWithRedirect exchanges an authorization code for tokens using
|
||||
// a caller-provided redirect URI. This supports alternate auth flows such as device
|
||||
// login while preserving the existing token parsing and storage behavior.
|
||||
func (o *CodexAuth) ExchangeCodeForTokensWithRedirect(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
||||
if pkceCodes == nil {
|
||||
return nil, fmt.Errorf("PKCE codes are required for token exchange")
|
||||
}
|
||||
if strings.TrimSpace(redirectURI) == "" {
|
||||
return nil, fmt.Errorf("redirect URI is required for token exchange")
|
||||
}
|
||||
|
||||
// Prepare token exchange request
|
||||
data := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"client_id": {openaiClientID},
|
||||
"client_id": {ClientID},
|
||||
"code": {code},
|
||||
"redirect_uri": {redirectURI},
|
||||
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||
"code_verifier": {pkceCodes.CodeVerifier},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode()))
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
@@ -163,13 +174,13 @@ func (o *CodexAuth) RefreshTokens(ctx context.Context, refreshToken string) (*Co
|
||||
}
|
||||
|
||||
data := url.Values{
|
||||
"client_id": {openaiClientID},
|
||||
"client_id": {ClientID},
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {refreshToken},
|
||||
"scope": {"openid profile email"},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode()))
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
||||
}
|
||||
@@ -265,6 +276,10 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
|
||||
if err == nil {
|
||||
return tokenData, nil
|
||||
}
|
||||
if isNonRetryableRefreshErr(err) {
|
||||
log.Warnf("Token refresh attempt %d failed with non-retryable error: %v", attempt+1, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
||||
@@ -273,6 +288,14 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
|
||||
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
func isNonRetryableRefreshErr(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
raw := strings.ToLower(err.Error())
|
||||
return strings.Contains(raw, "refresh_token_reused")
|
||||
}
|
||||
|
||||
// UpdateTokenStorage updates an existing CodexTokenStorage with new token data.
|
||||
// This is typically called after a successful token refresh to persist the new credentials.
|
||||
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {
|
||||
|
||||
44
internal/auth/codex/openai_auth_test.go
Normal file
44
internal/auth/codex/openai_auth_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package codex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) {
|
||||
var calls int32
|
||||
auth := &CodexAuth{
|
||||
httpClient: &http.Client{
|
||||
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":"invalid_grant","code":"refresh_token_reused"}`)),
|
||||
Header: make(http.Header),
|
||||
Request: req,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
_, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for non-retryable refresh failure")
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "refresh_token_reused") {
|
||||
t.Fatalf("expected refresh_token_reused in error, got: %v", err)
|
||||
}
|
||||
if got := atomic.LoadInt32(&calls); got != 1 {
|
||||
t.Fatalf("expected 1 refresh attempt, got %d", got)
|
||||
}
|
||||
}
|
||||
@@ -32,11 +32,21 @@ type CodexTokenStorage struct {
|
||||
Type string `json:"type"`
|
||||
// Expire is the timestamp when the current access token expires.
|
||||
Expire string `json:"expired"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *CodexTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Codex token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -58,7 +68,13 @@ func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
@@ -82,15 +84,21 @@ func (c *CopilotAuth) WaitForAuthorization(ctx context.Context, deviceCode *Devi
|
||||
}
|
||||
|
||||
// Fetch the GitHub username
|
||||
username, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
||||
userInfo, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
||||
if err != nil {
|
||||
log.Warnf("copilot: failed to fetch user info: %v", err)
|
||||
username = "unknown"
|
||||
}
|
||||
|
||||
username := userInfo.Login
|
||||
if username == "" {
|
||||
username = "github-user"
|
||||
}
|
||||
|
||||
return &CopilotAuthBundle{
|
||||
TokenData: tokenData,
|
||||
Username: username,
|
||||
Email: userInfo.Email,
|
||||
Name: userInfo.Name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -150,12 +158,12 @@ func (c *CopilotAuth) ValidateToken(ctx context.Context, accessToken string) (bo
|
||||
return false, "", nil
|
||||
}
|
||||
|
||||
username, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
|
||||
userInfo, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
return true, username, nil
|
||||
return true, userInfo.Login, nil
|
||||
}
|
||||
|
||||
// CreateTokenStorage creates a new CopilotTokenStorage from auth bundle.
|
||||
@@ -165,6 +173,8 @@ func (c *CopilotAuth) CreateTokenStorage(bundle *CopilotAuthBundle) *CopilotToke
|
||||
TokenType: bundle.TokenData.TokenType,
|
||||
Scope: bundle.TokenData.Scope,
|
||||
Username: bundle.Username,
|
||||
Email: bundle.Email,
|
||||
Name: bundle.Name,
|
||||
Type: "github-copilot",
|
||||
}
|
||||
}
|
||||
@@ -214,6 +224,97 @@ func (c *CopilotAuth) MakeAuthenticatedRequest(ctx context.Context, method, url
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// CopilotModelEntry represents a single model entry returned by the Copilot /models API.
|
||||
type CopilotModelEntry struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Version string `json:"version,omitempty"`
|
||||
Capabilities map[string]any `json:"capabilities,omitempty"`
|
||||
}
|
||||
|
||||
// CopilotModelsResponse represents the response from the Copilot /models endpoint.
|
||||
type CopilotModelsResponse struct {
|
||||
Data []CopilotModelEntry `json:"data"`
|
||||
Object string `json:"object"`
|
||||
}
|
||||
|
||||
// maxModelsResponseSize is the maximum allowed response size from the /models endpoint (2 MB).
|
||||
const maxModelsResponseSize = 2 * 1024 * 1024
|
||||
|
||||
// allowedCopilotAPIHosts is the set of hosts that are considered safe for Copilot API requests.
|
||||
var allowedCopilotAPIHosts = map[string]bool{
|
||||
"api.githubcopilot.com": true,
|
||||
"api.individual.githubcopilot.com": true,
|
||||
"api.business.githubcopilot.com": true,
|
||||
"copilot-proxy.githubusercontent.com": true,
|
||||
}
|
||||
|
||||
// ListModels fetches the list of available models from the Copilot API.
|
||||
// It requires a valid Copilot API token (not the GitHub access token).
|
||||
func (c *CopilotAuth) ListModels(ctx context.Context, apiToken *CopilotAPIToken) ([]CopilotModelEntry, error) {
|
||||
if apiToken == nil || apiToken.Token == "" {
|
||||
return nil, fmt.Errorf("copilot: api token is required for listing models")
|
||||
}
|
||||
|
||||
// Build models URL, validating the endpoint host to prevent SSRF.
|
||||
modelsURL := copilotAPIEndpoint + "/models"
|
||||
if ep := strings.TrimRight(apiToken.Endpoints.API, "/"); ep != "" {
|
||||
parsed, err := url.Parse(ep)
|
||||
if err == nil && parsed.Scheme == "https" && allowedCopilotAPIHosts[parsed.Host] {
|
||||
modelsURL = ep + "/models"
|
||||
} else {
|
||||
log.Warnf("copilot: ignoring untrusted API endpoint %q, using default", ep)
|
||||
}
|
||||
}
|
||||
|
||||
req, err := c.MakeAuthenticatedRequest(ctx, http.MethodGet, modelsURL, nil, apiToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("copilot: failed to create models request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("copilot: models request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("copilot list models: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
// Limit response body to prevent memory exhaustion.
|
||||
limitedReader := io.LimitReader(resp.Body, maxModelsResponseSize)
|
||||
bodyBytes, err := io.ReadAll(limitedReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("copilot: failed to read models response: %w", err)
|
||||
}
|
||||
|
||||
if !isHTTPSuccess(resp.StatusCode) {
|
||||
return nil, fmt.Errorf("copilot: list models failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var modelsResp CopilotModelsResponse
|
||||
if err = json.Unmarshal(bodyBytes, &modelsResp); err != nil {
|
||||
return nil, fmt.Errorf("copilot: failed to parse models response: %w", err)
|
||||
}
|
||||
|
||||
return modelsResp.Data, nil
|
||||
}
|
||||
|
||||
// ListModelsWithGitHubToken is a convenience method that exchanges a GitHub access token
|
||||
// for a Copilot API token and then fetches the available models.
|
||||
func (c *CopilotAuth) ListModelsWithGitHubToken(ctx context.Context, githubAccessToken string) ([]CopilotModelEntry, error) {
|
||||
apiToken, err := c.GetCopilotAPIToken(ctx, githubAccessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("copilot: failed to get API token for model listing: %w", err)
|
||||
}
|
||||
|
||||
return c.ListModels(ctx, apiToken)
|
||||
}
|
||||
|
||||
// buildChatCompletionURL builds the URL for chat completions API.
|
||||
func buildChatCompletionURL() string {
|
||||
return copilotAPIEndpoint + "/chat/completions"
|
||||
|
||||
@@ -53,7 +53,7 @@ func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
|
||||
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", copilotClientID)
|
||||
data.Set("scope", "user:email")
|
||||
data.Set("scope", "read:user user:email")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotDeviceCodeURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
@@ -211,15 +211,25 @@ func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode st
|
||||
}, nil
|
||||
}
|
||||
|
||||
// FetchUserInfo retrieves the GitHub username for the authenticated user.
|
||||
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (string, error) {
|
||||
// GitHubUserInfo holds GitHub user profile information.
|
||||
type GitHubUserInfo struct {
|
||||
// Login is the GitHub username.
|
||||
Login string
|
||||
// Email is the primary email address (may be empty if not public).
|
||||
Email string
|
||||
// Name is the display name.
|
||||
Name string
|
||||
}
|
||||
|
||||
// FetchUserInfo retrieves the GitHub user profile for the authenticated user.
|
||||
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (GitHubUserInfo, error) {
|
||||
if accessToken == "" {
|
||||
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotUserInfoURL, nil)
|
||||
if err != nil {
|
||||
return "", NewAuthenticationError(ErrUserInfoFailed, err)
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
@@ -227,7 +237,7 @@ func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", NewAuthenticationError(ErrUserInfoFailed, err)
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
@@ -237,19 +247,25 @@ func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string
|
||||
|
||||
if !isHTTPSuccess(resp.StatusCode) {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
|
||||
}
|
||||
|
||||
var userInfo struct {
|
||||
var raw struct {
|
||||
Login string `json:"login"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err = json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
||||
return "", NewAuthenticationError(ErrUserInfoFailed, err)
|
||||
if err = json.NewDecoder(resp.Body).Decode(&raw); err != nil {
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
|
||||
}
|
||||
|
||||
if userInfo.Login == "" {
|
||||
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
|
||||
if raw.Login == "" {
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
|
||||
}
|
||||
|
||||
return userInfo.Login, nil
|
||||
return GitHubUserInfo{
|
||||
Login: raw.Login,
|
||||
Email: raw.Email,
|
||||
Name: raw.Name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
213
internal/auth/copilot/oauth_test.go
Normal file
213
internal/auth/copilot/oauth_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// roundTripFunc lets us inject a custom transport for testing.
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
|
||||
|
||||
// newTestClient returns an *http.Client whose requests are redirected to the given test server,
|
||||
// regardless of the original URL host.
|
||||
func newTestClient(srv *httptest.Server) *http.Client {
|
||||
return &http.Client{
|
||||
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
req2 := req.Clone(req.Context())
|
||||
req2.URL.Scheme = "http"
|
||||
req2.URL.Host = strings.TrimPrefix(srv.URL, "http://")
|
||||
return srv.Client().Transport.RoundTrip(req2)
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchUserInfo_FullProfile verifies that FetchUserInfo returns login, email, and name.
|
||||
func TestFetchUserInfo_FullProfile(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"login": "octocat",
|
||||
"email": "octocat@github.com",
|
||||
"name": "The Octocat",
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||
info, err := client.FetchUserInfo(context.Background(), "test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if info.Login != "octocat" {
|
||||
t.Errorf("Login: got %q, want %q", info.Login, "octocat")
|
||||
}
|
||||
if info.Email != "octocat@github.com" {
|
||||
t.Errorf("Email: got %q, want %q", info.Email, "octocat@github.com")
|
||||
}
|
||||
if info.Name != "The Octocat" {
|
||||
t.Errorf("Name: got %q, want %q", info.Name, "The Octocat")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchUserInfo_EmptyEmail verifies graceful handling when email is absent (private account).
|
||||
func TestFetchUserInfo_EmptyEmail(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// GitHub returns null for private emails.
|
||||
_, _ = w.Write([]byte(`{"login":"privateuser","email":null,"name":"Private User"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||
info, err := client.FetchUserInfo(context.Background(), "test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if info.Login != "privateuser" {
|
||||
t.Errorf("Login: got %q, want %q", info.Login, "privateuser")
|
||||
}
|
||||
if info.Email != "" {
|
||||
t.Errorf("Email: got %q, want empty string", info.Email)
|
||||
}
|
||||
if info.Name != "Private User" {
|
||||
t.Errorf("Name: got %q, want %q", info.Name, "Private User")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchUserInfo_EmptyToken verifies error is returned for empty access token.
|
||||
func TestFetchUserInfo_EmptyToken(t *testing.T) {
|
||||
client := &DeviceFlowClient{httpClient: http.DefaultClient}
|
||||
_, err := client.FetchUserInfo(context.Background(), "")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty token, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchUserInfo_EmptyLogin verifies error is returned when API returns no login.
|
||||
func TestFetchUserInfo_EmptyLogin(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"email":"someone@example.com","name":"No Login"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||
_, err := client.FetchUserInfo(context.Background(), "test-token")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty login, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchUserInfo_HTTPError verifies error is returned on non-2xx response.
|
||||
func TestFetchUserInfo_HTTPError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"message":"Bad credentials"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||
_, err := client.FetchUserInfo(context.Background(), "bad-token")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 401 response, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCopilotTokenStorage_EmailNameFields verifies Email and Name serialise correctly.
|
||||
func TestCopilotTokenStorage_EmailNameFields(t *testing.T) {
|
||||
ts := &CopilotTokenStorage{
|
||||
AccessToken: "ghu_abc",
|
||||
TokenType: "bearer",
|
||||
Scope: "read:user user:email",
|
||||
Username: "octocat",
|
||||
Email: "octocat@github.com",
|
||||
Name: "The Octocat",
|
||||
Type: "github-copilot",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(ts)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal error: %v", err)
|
||||
}
|
||||
|
||||
var out map[string]any
|
||||
if err = json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
}
|
||||
|
||||
for _, key := range []string{"access_token", "username", "email", "name", "type"} {
|
||||
if _, ok := out[key]; !ok {
|
||||
t.Errorf("expected key %q in JSON output, not found", key)
|
||||
}
|
||||
}
|
||||
if out["email"] != "octocat@github.com" {
|
||||
t.Errorf("email: got %v, want %q", out["email"], "octocat@github.com")
|
||||
}
|
||||
if out["name"] != "The Octocat" {
|
||||
t.Errorf("name: got %v, want %q", out["name"], "The Octocat")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCopilotTokenStorage_OmitEmptyEmailName verifies email/name are omitted when empty (omitempty).
|
||||
func TestCopilotTokenStorage_OmitEmptyEmailName(t *testing.T) {
|
||||
ts := &CopilotTokenStorage{
|
||||
AccessToken: "ghu_abc",
|
||||
Username: "octocat",
|
||||
Type: "github-copilot",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(ts)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal error: %v", err)
|
||||
}
|
||||
|
||||
var out map[string]any
|
||||
if err = json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := out["email"]; ok {
|
||||
t.Error("email key should be omitted when empty (omitempty), but was present")
|
||||
}
|
||||
if _, ok := out["name"]; ok {
|
||||
t.Error("name key should be omitted when empty (omitempty), but was present")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCopilotAuthBundle_EmailNameFields verifies bundle carries email and name through the pipeline.
|
||||
func TestCopilotAuthBundle_EmailNameFields(t *testing.T) {
|
||||
bundle := &CopilotAuthBundle{
|
||||
TokenData: &CopilotTokenData{AccessToken: "ghu_abc"},
|
||||
Username: "octocat",
|
||||
Email: "octocat@github.com",
|
||||
Name: "The Octocat",
|
||||
}
|
||||
if bundle.Email != "octocat@github.com" {
|
||||
t.Errorf("bundle.Email: got %q, want %q", bundle.Email, "octocat@github.com")
|
||||
}
|
||||
if bundle.Name != "The Octocat" {
|
||||
t.Errorf("bundle.Name: got %q, want %q", bundle.Name, "The Octocat")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGitHubUserInfo_Struct verifies the exported GitHubUserInfo struct fields are accessible.
|
||||
func TestGitHubUserInfo_Struct(t *testing.T) {
|
||||
info := GitHubUserInfo{
|
||||
Login: "octocat",
|
||||
Email: "octocat@github.com",
|
||||
Name: "The Octocat",
|
||||
}
|
||||
if info.Login == "" || info.Email == "" || info.Name == "" {
|
||||
t.Error("GitHubUserInfo fields should not be empty")
|
||||
}
|
||||
}
|
||||
@@ -26,6 +26,10 @@ type CopilotTokenStorage struct {
|
||||
ExpiresAt string `json:"expires_at,omitempty"`
|
||||
// Username is the GitHub username associated with this token.
|
||||
Username string `json:"username"`
|
||||
// Email is the GitHub email address associated with this token.
|
||||
Email string `json:"email,omitempty"`
|
||||
// Name is the GitHub display name associated with this token.
|
||||
Name string `json:"name,omitempty"`
|
||||
// Type indicates the authentication provider type, always "github-copilot" for this storage.
|
||||
Type string `json:"type"`
|
||||
}
|
||||
@@ -46,6 +50,10 @@ type CopilotAuthBundle struct {
|
||||
TokenData *CopilotTokenData
|
||||
// Username is the GitHub username.
|
||||
Username string
|
||||
// Email is the GitHub email address.
|
||||
Email string
|
||||
// Name is the GitHub display name.
|
||||
Name string
|
||||
}
|
||||
|
||||
// DeviceCodeResponse represents GitHub's device code response.
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -27,18 +28,19 @@ import (
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
// OAuth configuration constants for Gemini
|
||||
const (
|
||||
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
geminiOauthClientSecret = "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"
|
||||
ClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
ClientSecret = "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"
|
||||
DefaultCallbackPort = 8085
|
||||
)
|
||||
|
||||
var (
|
||||
geminiOauthScopes = []string{
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
}
|
||||
)
|
||||
// OAuth scopes for Gemini authentication
|
||||
var Scopes = []string{
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
}
|
||||
|
||||
// GeminiAuth provides methods for handling the Gemini OAuth2 authentication flow.
|
||||
// It encapsulates the logic for obtaining, storing, and refreshing authentication tokens
|
||||
@@ -46,6 +48,13 @@ var (
|
||||
type GeminiAuth struct {
|
||||
}
|
||||
|
||||
// WebLoginOptions customizes the interactive OAuth flow.
|
||||
type WebLoginOptions struct {
|
||||
NoBrowser bool
|
||||
CallbackPort int
|
||||
Prompt func(string) (string, error)
|
||||
}
|
||||
|
||||
// NewGeminiAuth creates a new instance of GeminiAuth.
|
||||
func NewGeminiAuth() *GeminiAuth {
|
||||
return &GeminiAuth{}
|
||||
@@ -59,12 +68,18 @@ func NewGeminiAuth() *GeminiAuth {
|
||||
// - ctx: The context for the HTTP client
|
||||
// - ts: The Gemini token storage containing authentication tokens
|
||||
// - cfg: The configuration containing proxy settings
|
||||
// - noBrowser: Optional parameter to disable browser opening
|
||||
// - opts: Optional parameters to customize browser and prompt behavior
|
||||
//
|
||||
// Returns:
|
||||
// - *http.Client: An HTTP client configured with authentication
|
||||
// - error: An error if the client configuration fails, nil otherwise
|
||||
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, noBrowser ...bool) (*http.Client, error) {
|
||||
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
|
||||
callbackPort := DefaultCallbackPort
|
||||
if opts != nil && opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
|
||||
|
||||
// Configure proxy settings for the HTTP client if a proxy URL is provided.
|
||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
||||
if err == nil {
|
||||
@@ -76,7 +91,8 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
|
||||
auth := &proxy.Auth{User: username, Password: password}
|
||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
|
||||
if errSOCKS5 != nil {
|
||||
log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
||||
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
||||
return nil, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
|
||||
}
|
||||
transport = &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
@@ -96,10 +112,10 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
|
||||
|
||||
// Configure the OAuth2 client.
|
||||
conf := &oauth2.Config{
|
||||
ClientID: geminiOauthClientID,
|
||||
ClientSecret: geminiOauthClientSecret,
|
||||
RedirectURL: "http://localhost:8085/oauth2callback", // This will be used by the local server.
|
||||
Scopes: geminiOauthScopes,
|
||||
ClientID: ClientID,
|
||||
ClientSecret: ClientSecret,
|
||||
RedirectURL: callbackURL, // This will be used by the local server.
|
||||
Scopes: Scopes,
|
||||
Endpoint: google.Endpoint,
|
||||
}
|
||||
|
||||
@@ -108,7 +124,7 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
|
||||
// If no token is found in storage, initiate the web-based OAuth flow.
|
||||
if ts.Token == nil {
|
||||
fmt.Printf("Could not load token from file, starting OAuth flow.\n")
|
||||
token, err = g.getTokenFromWeb(ctx, conf, noBrowser...)
|
||||
token, err = g.getTokenFromWeb(ctx, conf, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get token from web: %w", err)
|
||||
}
|
||||
@@ -182,9 +198,9 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
|
||||
}
|
||||
|
||||
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
|
||||
ifToken["client_id"] = geminiOauthClientID
|
||||
ifToken["client_secret"] = geminiOauthClientSecret
|
||||
ifToken["scopes"] = geminiOauthScopes
|
||||
ifToken["client_id"] = ClientID
|
||||
ifToken["client_secret"] = ClientSecret
|
||||
ifToken["scopes"] = Scopes
|
||||
ifToken["universe_domain"] = "googleapis.com"
|
||||
|
||||
ts := GeminiTokenStorage{
|
||||
@@ -204,60 +220,84 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
|
||||
// Parameters:
|
||||
// - ctx: The context for the HTTP client
|
||||
// - config: The OAuth2 configuration
|
||||
// - noBrowser: Optional parameter to disable browser opening
|
||||
// - opts: Optional parameters to customize browser and prompt behavior
|
||||
//
|
||||
// Returns:
|
||||
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow
|
||||
// - error: An error if the token acquisition fails, nil otherwise
|
||||
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, noBrowser ...bool) (*oauth2.Token, error) {
|
||||
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
|
||||
callbackPort := DefaultCallbackPort
|
||||
if opts != nil && opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
|
||||
|
||||
// Use a channel to pass the authorization code from the HTTP handler to the main function.
|
||||
codeChan := make(chan string)
|
||||
errChan := make(chan error)
|
||||
codeChan := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
// Create a new HTTP server with its own multiplexer.
|
||||
mux := http.NewServeMux()
|
||||
server := &http.Server{Addr: ":8085", Handler: mux}
|
||||
config.RedirectURL = "http://localhost:8085/oauth2callback"
|
||||
server := &http.Server{Addr: fmt.Sprintf(":%d", callbackPort), Handler: mux}
|
||||
config.RedirectURL = callbackURL
|
||||
|
||||
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.URL.Query().Get("error"); err != "" {
|
||||
_, _ = fmt.Fprintf(w, "Authentication failed: %s", err)
|
||||
errChan <- fmt.Errorf("authentication failed via callback: %s", err)
|
||||
select {
|
||||
case errChan <- fmt.Errorf("authentication failed via callback: %s", err):
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
code := r.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
_, _ = fmt.Fprint(w, "Authentication failed: code not found.")
|
||||
errChan <- fmt.Errorf("code not found in callback")
|
||||
select {
|
||||
case errChan <- fmt.Errorf("code not found in callback"):
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
_, _ = fmt.Fprint(w, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
|
||||
codeChan <- code
|
||||
select {
|
||||
case codeChan <- code:
|
||||
default:
|
||||
}
|
||||
})
|
||||
|
||||
// Start the server in a goroutine.
|
||||
go func() {
|
||||
if err := server.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Fatalf("ListenAndServe(): %v", err)
|
||||
log.Errorf("ListenAndServe(): %v", err)
|
||||
select {
|
||||
case errChan <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Open the authorization URL in the user's browser.
|
||||
authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
|
||||
|
||||
if len(noBrowser) == 1 && !noBrowser[0] {
|
||||
noBrowser := false
|
||||
if opts != nil {
|
||||
noBrowser = opts.NoBrowser
|
||||
}
|
||||
|
||||
if !noBrowser {
|
||||
fmt.Println("Opening browser for authentication...")
|
||||
|
||||
// Check if browser is available
|
||||
if !browser.IsAvailable() {
|
||||
log.Warn("No browser available on this system")
|
||||
util.PrintSSHTunnelInstructions(8085)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
||||
} else {
|
||||
if err := browser.OpenURL(authURL); err != nil {
|
||||
authErr := codex.NewAuthenticationError(codex.ErrBrowserOpenFailed, err)
|
||||
log.Warn(codex.GetUserFriendlyMessage(authErr))
|
||||
util.PrintSSHTunnelInstructions(8085)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
||||
|
||||
// Log platform info for debugging
|
||||
@@ -268,7 +308,7 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
util.PrintSSHTunnelInstructions(8085)
|
||||
util.PrintSSHTunnelInstructions(callbackPort)
|
||||
fmt.Printf("Please open this URL in your browser:\n\n%s\n", authURL)
|
||||
}
|
||||
|
||||
@@ -276,13 +316,60 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
||||
|
||||
// Wait for the authorization code or an error.
|
||||
var authCode string
|
||||
select {
|
||||
case code := <-codeChan:
|
||||
authCode = code
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
case <-time.After(5 * time.Minute): // Timeout
|
||||
return nil, fmt.Errorf("oauth flow timed out")
|
||||
timeoutTimer := time.NewTimer(5 * time.Minute)
|
||||
defer timeoutTimer.Stop()
|
||||
|
||||
var manualPromptTimer *time.Timer
|
||||
var manualPromptC <-chan time.Time
|
||||
if opts != nil && opts.Prompt != nil {
|
||||
manualPromptTimer = time.NewTimer(15 * time.Second)
|
||||
manualPromptC = manualPromptTimer.C
|
||||
defer manualPromptTimer.Stop()
|
||||
}
|
||||
|
||||
waitForCallback:
|
||||
for {
|
||||
select {
|
||||
case code := <-codeChan:
|
||||
authCode = code
|
||||
break waitForCallback
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
case <-manualPromptC:
|
||||
manualPromptC = nil
|
||||
if manualPromptTimer != nil {
|
||||
manualPromptTimer.Stop()
|
||||
}
|
||||
select {
|
||||
case code := <-codeChan:
|
||||
authCode = code
|
||||
break waitForCallback
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
default:
|
||||
}
|
||||
input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsed, err := misc.ParseOAuthCallback(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsed == nil {
|
||||
continue
|
||||
}
|
||||
if parsed.Error != "" {
|
||||
return nil, fmt.Errorf("authentication failed via callback: %s", parsed.Error)
|
||||
}
|
||||
if parsed.Code == "" {
|
||||
return nil, fmt.Errorf("code not found in callback")
|
||||
}
|
||||
authCode = parsed.Code
|
||||
break waitForCallback
|
||||
case <-timeoutTimer.C:
|
||||
return nil, fmt.Errorf("oauth flow timed out")
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown the server.
|
||||
|
||||
@@ -35,11 +35,21 @@ type GeminiTokenStorage struct {
|
||||
|
||||
// Type indicates the authentication provider type, always "gemini" for this storage.
|
||||
Type string `json:"type"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -49,6 +59,11 @@ type GeminiTokenStorage struct {
|
||||
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "gemini"
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
@@ -63,7 +78,9 @@ func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
}
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
enc := json.NewEncoder(f)
|
||||
enc.SetIndent("", " ")
|
||||
if err := enc.Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package iflow
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -36,3 +39,61 @@ func SanitizeIFlowFileName(raw string) string {
|
||||
}
|
||||
return strings.TrimSpace(result.String())
|
||||
}
|
||||
|
||||
// ExtractBXAuth extracts the BXAuth value from a cookie string.
|
||||
func ExtractBXAuth(cookie string) string {
|
||||
parts := strings.Split(cookie, ";")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(part, "BXAuth=") {
|
||||
return strings.TrimPrefix(part, "BXAuth=")
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// CheckDuplicateBXAuth checks if the given BXAuth value already exists in any iflow auth file.
|
||||
// Returns the path of the existing file if found, empty string otherwise.
|
||||
func CheckDuplicateBXAuth(authDir, bxAuth string) (string, error) {
|
||||
if bxAuth == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(authDir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return "", nil
|
||||
}
|
||||
return "", fmt.Errorf("read auth dir failed: %w", err)
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if !strings.HasPrefix(name, "iflow-") || !strings.HasSuffix(name, ".json") {
|
||||
continue
|
||||
}
|
||||
|
||||
filePath := filepath.Join(authDir, name)
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var tokenData struct {
|
||||
Cookie string `json:"cookie"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &tokenData); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
existingBXAuth := ExtractBXAuth(tokenData.Cookie)
|
||||
if existingBXAuth != "" && existingBXAuth == bxAuth {
|
||||
return filePath, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -28,10 +29,21 @@ const (
|
||||
iFlowAPIKeyEndpoint = "https://platform.iflow.cn/api/openapi/apikey"
|
||||
|
||||
// Client credentials provided by iFlow for the Code Assist integration.
|
||||
iFlowOAuthClientID = "10009311001"
|
||||
iFlowOAuthClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW"
|
||||
iFlowOAuthClientID = "10009311001"
|
||||
// Default client secret (can be overridden via IFLOW_CLIENT_SECRET env var)
|
||||
defaultIFlowClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW"
|
||||
)
|
||||
|
||||
// getIFlowClientSecret returns the iFlow OAuth client secret.
|
||||
// It first checks the IFLOW_CLIENT_SECRET environment variable,
|
||||
// falling back to the default value if not set.
|
||||
func getIFlowClientSecret() string {
|
||||
if secret := os.Getenv("IFLOW_CLIENT_SECRET"); secret != "" {
|
||||
return secret
|
||||
}
|
||||
return defaultIFlowClientSecret
|
||||
}
|
||||
|
||||
// DefaultAPIBaseURL is the canonical chat completions endpoint.
|
||||
const DefaultAPIBaseURL = "https://apis.iflow.cn/v1"
|
||||
|
||||
@@ -72,7 +84,7 @@ func (ia *IFlowAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectUR
|
||||
form.Set("code", code)
|
||||
form.Set("redirect_uri", redirectURI)
|
||||
form.Set("client_id", iFlowOAuthClientID)
|
||||
form.Set("client_secret", iFlowOAuthClientSecret)
|
||||
form.Set("client_secret", getIFlowClientSecret())
|
||||
|
||||
req, err := ia.newTokenRequest(ctx, form)
|
||||
if err != nil {
|
||||
@@ -88,7 +100,7 @@ func (ia *IFlowAuth) RefreshTokens(ctx context.Context, refreshToken string) (*I
|
||||
form.Set("grant_type", "refresh_token")
|
||||
form.Set("refresh_token", refreshToken)
|
||||
form.Set("client_id", iFlowOAuthClientID)
|
||||
form.Set("client_secret", iFlowOAuthClientSecret)
|
||||
form.Set("client_secret", getIFlowClientSecret())
|
||||
|
||||
req, err := ia.newTokenRequest(ctx, form)
|
||||
if err != nil {
|
||||
@@ -104,7 +116,7 @@ func (ia *IFlowAuth) newTokenRequest(ctx context.Context, form url.Values) (*htt
|
||||
return nil, fmt.Errorf("iflow token: create request failed: %w", err)
|
||||
}
|
||||
|
||||
basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + iFlowOAuthClientSecret))
|
||||
basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + getIFlowClientSecret()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Basic "+basic)
|
||||
@@ -309,17 +321,23 @@ func (ia *IFlowAuth) AuthenticateWithCookie(ctx context.Context, cookie string)
|
||||
return nil, fmt.Errorf("iflow cookie authentication: cookie is empty")
|
||||
}
|
||||
|
||||
// First, get initial API key information using GET request
|
||||
// First, get initial API key information using GET request to obtain the name
|
||||
keyInfo, err := ia.fetchAPIKeyInfo(ctx, cookie)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iflow cookie authentication: fetch initial API key info failed: %w", err)
|
||||
}
|
||||
|
||||
// Convert to token data format
|
||||
// Refresh the API key using POST request
|
||||
refreshedKeyInfo, err := ia.RefreshAPIKey(ctx, cookie, keyInfo.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iflow cookie authentication: refresh API key failed: %w", err)
|
||||
}
|
||||
|
||||
// Convert to token data format using refreshed key
|
||||
data := &IFlowTokenData{
|
||||
APIKey: keyInfo.APIKey,
|
||||
Expire: keyInfo.ExpireTime,
|
||||
Email: keyInfo.Name,
|
||||
APIKey: refreshedKeyInfo.APIKey,
|
||||
Expire: refreshedKeyInfo.ExpireTime,
|
||||
Email: refreshedKeyInfo.Name,
|
||||
Cookie: cookie,
|
||||
}
|
||||
|
||||
@@ -488,11 +506,18 @@ func (ia *IFlowAuth) CreateCookieTokenStorage(data *IFlowTokenData) *IFlowTokenS
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only save the BXAuth field from the cookie
|
||||
bxAuth := ExtractBXAuth(data.Cookie)
|
||||
cookieToSave := ""
|
||||
if bxAuth != "" {
|
||||
cookieToSave = "BXAuth=" + bxAuth + ";"
|
||||
}
|
||||
|
||||
return &IFlowTokenStorage{
|
||||
APIKey: data.APIKey,
|
||||
Email: data.Email,
|
||||
Expire: data.Expire,
|
||||
Cookie: data.Cookie,
|
||||
Cookie: cookieToSave,
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
Type: "iflow",
|
||||
}
|
||||
|
||||
@@ -21,6 +21,15 @@ type IFlowTokenStorage struct {
|
||||
Scope string `json:"scope"`
|
||||
Cookie string `json:"cookie"`
|
||||
Type string `json:"type"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *IFlowTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serialises the token storage to disk.
|
||||
@@ -37,7 +46,13 @@ func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("iflow token: encode token failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
168
internal/auth/kilo/kilo_auth.go
Normal file
168
internal/auth/kilo/kilo_auth.go
Normal file
@@ -0,0 +1,168 @@
|
||||
// Package kilo provides authentication and token management functionality
|
||||
// for Kilo AI services.
|
||||
package kilo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// BaseURL is the base URL for the Kilo AI API.
|
||||
BaseURL = "https://api.kilo.ai/api"
|
||||
)
|
||||
|
||||
// DeviceAuthResponse represents the response from initiating device flow.
|
||||
type DeviceAuthResponse struct {
|
||||
Code string `json:"code"`
|
||||
VerificationURL string `json:"verificationUrl"`
|
||||
ExpiresIn int `json:"expiresIn"`
|
||||
}
|
||||
|
||||
// DeviceStatusResponse represents the response when polling for device flow status.
|
||||
type DeviceStatusResponse struct {
|
||||
Status string `json:"status"`
|
||||
Token string `json:"token"`
|
||||
UserEmail string `json:"userEmail"`
|
||||
}
|
||||
|
||||
// Profile represents the user profile from Kilo AI.
|
||||
type Profile struct {
|
||||
Email string `json:"email"`
|
||||
Orgs []Organization `json:"organizations"`
|
||||
}
|
||||
|
||||
// Organization represents a Kilo AI organization.
|
||||
type Organization struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Defaults represents default settings for an organization or user.
|
||||
type Defaults struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
// KiloAuth provides methods for handling the Kilo AI authentication flow.
|
||||
type KiloAuth struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewKiloAuth creates a new instance of KiloAuth.
|
||||
func NewKiloAuth() *KiloAuth {
|
||||
return &KiloAuth{
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// InitiateDeviceFlow starts the device authentication flow.
|
||||
func (k *KiloAuth) InitiateDeviceFlow(ctx context.Context) (*DeviceAuthResponse, error) {
|
||||
resp, err := k.client.Post(BaseURL+"/device-auth/codes", "application/json", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to initiate device flow: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var data DeviceAuthResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
// PollForToken polls for the device flow completion.
|
||||
func (k *KiloAuth) PollForToken(ctx context.Context, code string) (*DeviceStatusResponse, error) {
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
resp, err := k.client.Get(BaseURL + "/device-auth/codes/" + code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var data DeviceStatusResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch data.Status {
|
||||
case "approved":
|
||||
return &data, nil
|
||||
case "denied", "expired":
|
||||
return nil, fmt.Errorf("device flow %s", data.Status)
|
||||
case "pending":
|
||||
continue
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown status: %s", data.Status)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetProfile fetches the user's profile.
|
||||
func (k *KiloAuth) GetProfile(ctx context.Context, token string) (*Profile, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", BaseURL+"/profile", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create get profile request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
resp, err := k.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to get profile: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var profile Profile
|
||||
if err := json.NewDecoder(resp.Body).Decode(&profile); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &profile, nil
|
||||
}
|
||||
|
||||
// GetDefaults fetches default settings for an organization.
|
||||
func (k *KiloAuth) GetDefaults(ctx context.Context, token, orgID string) (*Defaults, error) {
|
||||
url := BaseURL + "/defaults"
|
||||
if orgID != "" {
|
||||
url = BaseURL + "/organizations/" + orgID + "/defaults"
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create get defaults request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
resp, err := k.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to get defaults: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var defaults Defaults
|
||||
if err := json.NewDecoder(resp.Body).Decode(&defaults); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &defaults, nil
|
||||
}
|
||||
60
internal/auth/kilo/kilo_token.go
Normal file
60
internal/auth/kilo/kilo_token.go
Normal file
@@ -0,0 +1,60 @@
|
||||
// Package kilo provides authentication and token management functionality
|
||||
// for Kilo AI services.
|
||||
package kilo
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// KiloTokenStorage stores token information for Kilo AI authentication.
|
||||
type KiloTokenStorage struct {
|
||||
// Token is the Kilo access token.
|
||||
Token string `json:"kilocodeToken"`
|
||||
|
||||
// OrganizationID is the Kilo organization ID.
|
||||
OrganizationID string `json:"kilocodeOrganizationId"`
|
||||
|
||||
// Model is the default model to use.
|
||||
Model string `json:"kilocodeModel"`
|
||||
|
||||
// Email is the email address of the authenticated user.
|
||||
Email string `json:"email"`
|
||||
|
||||
// Type indicates the authentication provider type, always "kilo" for this storage.
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Kilo token storage to a JSON file.
|
||||
func (ts *KiloTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "kilo"
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
f, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := f.Close(); errClose != nil {
|
||||
log.Errorf("failed to close file: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CredentialFileName returns the filename used to persist Kilo credentials.
|
||||
func CredentialFileName(email string) string {
|
||||
return fmt.Sprintf("kilo-%s.json", email)
|
||||
}
|
||||
396
internal/auth/kimi/kimi.go
Normal file
396
internal/auth/kimi/kimi.go
Normal file
@@ -0,0 +1,396 @@
|
||||
// Package kimi provides authentication and token management for Kimi (Moonshot AI) API.
|
||||
// It handles the RFC 8628 OAuth2 Device Authorization Grant flow for secure authentication.
|
||||
package kimi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// kimiClientID is Kimi Code's OAuth client ID.
|
||||
kimiClientID = "17e5f671-d194-4dfb-9706-5516cb48c098"
|
||||
// kimiOAuthHost is the OAuth server endpoint.
|
||||
kimiOAuthHost = "https://auth.kimi.com"
|
||||
// kimiDeviceCodeURL is the endpoint for requesting device codes.
|
||||
kimiDeviceCodeURL = kimiOAuthHost + "/api/oauth/device_authorization"
|
||||
// kimiTokenURL is the endpoint for exchanging device codes for tokens.
|
||||
kimiTokenURL = kimiOAuthHost + "/api/oauth/token"
|
||||
// KimiAPIBaseURL is the base URL for Kimi API requests.
|
||||
KimiAPIBaseURL = "https://api.kimi.com/coding"
|
||||
// defaultPollInterval is the default interval for polling token endpoint.
|
||||
defaultPollInterval = 5 * time.Second
|
||||
// maxPollDuration is the maximum time to wait for user authorization.
|
||||
maxPollDuration = 15 * time.Minute
|
||||
// refreshThresholdSeconds is when to refresh token before expiry (5 minutes).
|
||||
refreshThresholdSeconds = 300
|
||||
)
|
||||
|
||||
// KimiAuth handles Kimi authentication flow.
|
||||
type KimiAuth struct {
|
||||
deviceClient *DeviceFlowClient
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewKimiAuth creates a new KimiAuth service instance.
|
||||
func NewKimiAuth(cfg *config.Config) *KimiAuth {
|
||||
return &KimiAuth{
|
||||
deviceClient: NewDeviceFlowClient(cfg),
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// StartDeviceFlow initiates the device flow authentication.
|
||||
func (k *KimiAuth) StartDeviceFlow(ctx context.Context) (*DeviceCodeResponse, error) {
|
||||
return k.deviceClient.RequestDeviceCode(ctx)
|
||||
}
|
||||
|
||||
// WaitForAuthorization polls for user authorization and returns the auth bundle.
|
||||
func (k *KimiAuth) WaitForAuthorization(ctx context.Context, deviceCode *DeviceCodeResponse) (*KimiAuthBundle, error) {
|
||||
tokenData, err := k.deviceClient.PollForToken(ctx, deviceCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &KimiAuthBundle{
|
||||
TokenData: tokenData,
|
||||
DeviceID: k.deviceClient.deviceID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateTokenStorage creates a new KimiTokenStorage from auth bundle.
|
||||
func (k *KimiAuth) CreateTokenStorage(bundle *KimiAuthBundle) *KimiTokenStorage {
|
||||
expired := ""
|
||||
if bundle.TokenData.ExpiresAt > 0 {
|
||||
expired = time.Unix(bundle.TokenData.ExpiresAt, 0).UTC().Format(time.RFC3339)
|
||||
}
|
||||
return &KimiTokenStorage{
|
||||
AccessToken: bundle.TokenData.AccessToken,
|
||||
RefreshToken: bundle.TokenData.RefreshToken,
|
||||
TokenType: bundle.TokenData.TokenType,
|
||||
Scope: bundle.TokenData.Scope,
|
||||
DeviceID: strings.TrimSpace(bundle.DeviceID),
|
||||
Expired: expired,
|
||||
Type: "kimi",
|
||||
}
|
||||
}
|
||||
|
||||
// DeviceFlowClient handles the OAuth2 device flow for Kimi.
|
||||
type DeviceFlowClient struct {
|
||||
httpClient *http.Client
|
||||
cfg *config.Config
|
||||
deviceID string
|
||||
}
|
||||
|
||||
// NewDeviceFlowClient creates a new device flow client.
|
||||
func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
|
||||
return NewDeviceFlowClientWithDeviceID(cfg, "")
|
||||
}
|
||||
|
||||
// NewDeviceFlowClientWithDeviceID creates a new device flow client with the specified device ID.
|
||||
func NewDeviceFlowClientWithDeviceID(cfg *config.Config, deviceID string) *DeviceFlowClient {
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
if cfg != nil {
|
||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||
}
|
||||
resolvedDeviceID := strings.TrimSpace(deviceID)
|
||||
if resolvedDeviceID == "" {
|
||||
resolvedDeviceID = getOrCreateDeviceID()
|
||||
}
|
||||
return &DeviceFlowClient{
|
||||
httpClient: client,
|
||||
cfg: cfg,
|
||||
deviceID: resolvedDeviceID,
|
||||
}
|
||||
}
|
||||
|
||||
// getOrCreateDeviceID returns an in-memory device ID for the current authentication flow.
|
||||
func getOrCreateDeviceID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
// getDeviceModel returns a device model string.
|
||||
func getDeviceModel() string {
|
||||
osName := runtime.GOOS
|
||||
arch := runtime.GOARCH
|
||||
|
||||
switch osName {
|
||||
case "darwin":
|
||||
return fmt.Sprintf("macOS %s", arch)
|
||||
case "windows":
|
||||
return fmt.Sprintf("Windows %s", arch)
|
||||
case "linux":
|
||||
return fmt.Sprintf("Linux %s", arch)
|
||||
default:
|
||||
return fmt.Sprintf("%s %s", osName, arch)
|
||||
}
|
||||
}
|
||||
|
||||
// getHostname returns the machine hostname.
|
||||
func getHostname() string {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
return "unknown"
|
||||
}
|
||||
return hostname
|
||||
}
|
||||
|
||||
// commonHeaders returns headers required for Kimi API requests.
|
||||
func (c *DeviceFlowClient) commonHeaders() map[string]string {
|
||||
return map[string]string{
|
||||
"X-Msh-Platform": "cli-proxy-api",
|
||||
"X-Msh-Version": "1.0.0",
|
||||
"X-Msh-Device-Name": getHostname(),
|
||||
"X-Msh-Device-Model": getDeviceModel(),
|
||||
"X-Msh-Device-Id": c.deviceID,
|
||||
}
|
||||
}
|
||||
|
||||
// RequestDeviceCode initiates the device flow by requesting a device code from Kimi.
|
||||
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", kimiClientID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, kimiDeviceCodeURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to create device code request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
for k, v := range c.commonHeaders() {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: device code request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi device code: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to read device code response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("kimi: device code request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var deviceCode DeviceCodeResponse
|
||||
if err = json.Unmarshal(bodyBytes, &deviceCode); err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to parse device code response: %w", err)
|
||||
}
|
||||
|
||||
return &deviceCode, nil
|
||||
}
|
||||
|
||||
// PollForToken polls the token endpoint until the user authorizes or the device code expires.
|
||||
func (c *DeviceFlowClient) PollForToken(ctx context.Context, deviceCode *DeviceCodeResponse) (*KimiTokenData, error) {
|
||||
if deviceCode == nil {
|
||||
return nil, fmt.Errorf("kimi: device code is nil")
|
||||
}
|
||||
|
||||
interval := time.Duration(deviceCode.Interval) * time.Second
|
||||
if interval < defaultPollInterval {
|
||||
interval = defaultPollInterval
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(maxPollDuration)
|
||||
if deviceCode.ExpiresIn > 0 {
|
||||
codeDeadline := time.Now().Add(time.Duration(deviceCode.ExpiresIn) * time.Second)
|
||||
if codeDeadline.Before(deadline) {
|
||||
deadline = codeDeadline
|
||||
}
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("kimi: context cancelled: %w", ctx.Err())
|
||||
case <-ticker.C:
|
||||
if time.Now().After(deadline) {
|
||||
return nil, fmt.Errorf("kimi: device code expired")
|
||||
}
|
||||
|
||||
token, pollErr, shouldContinue := c.exchangeDeviceCode(ctx, deviceCode.DeviceCode)
|
||||
if token != nil {
|
||||
return token, nil
|
||||
}
|
||||
if !shouldContinue {
|
||||
return nil, pollErr
|
||||
}
|
||||
// Continue polling
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// exchangeDeviceCode attempts to exchange the device code for an access token.
|
||||
// Returns (token, error, shouldContinue).
|
||||
func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode string) (*KimiTokenData, error, bool) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", kimiClientID)
|
||||
data.Set("device_code", deviceCode)
|
||||
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, kimiTokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to create token request: %w", err), false
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
for k, v := range c.commonHeaders() {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: token request failed: %w", err), false
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi token exchange: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to read token response: %w", err), false
|
||||
}
|
||||
|
||||
// Parse response - Kimi returns 200 for both success and pending states
|
||||
var oauthResp struct {
|
||||
Error string `json:"error"`
|
||||
ErrorDescription string `json:"error_description"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn float64 `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
if err = json.Unmarshal(bodyBytes, &oauthResp); err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to parse token response: %w", err), false
|
||||
}
|
||||
|
||||
if oauthResp.Error != "" {
|
||||
switch oauthResp.Error {
|
||||
case "authorization_pending":
|
||||
return nil, nil, true // Continue polling
|
||||
case "slow_down":
|
||||
return nil, nil, true // Continue polling (with increased interval handled by caller)
|
||||
case "expired_token":
|
||||
return nil, fmt.Errorf("kimi: device code expired"), false
|
||||
case "access_denied":
|
||||
return nil, fmt.Errorf("kimi: access denied by user"), false
|
||||
default:
|
||||
return nil, fmt.Errorf("kimi: OAuth error: %s - %s", oauthResp.Error, oauthResp.ErrorDescription), false
|
||||
}
|
||||
}
|
||||
|
||||
if oauthResp.AccessToken == "" {
|
||||
return nil, fmt.Errorf("kimi: empty access token in response"), false
|
||||
}
|
||||
|
||||
var expiresAt int64
|
||||
if oauthResp.ExpiresIn > 0 {
|
||||
expiresAt = time.Now().Unix() + int64(oauthResp.ExpiresIn)
|
||||
}
|
||||
|
||||
return &KimiTokenData{
|
||||
AccessToken: oauthResp.AccessToken,
|
||||
RefreshToken: oauthResp.RefreshToken,
|
||||
TokenType: oauthResp.TokenType,
|
||||
ExpiresAt: expiresAt,
|
||||
Scope: oauthResp.Scope,
|
||||
}, nil, false
|
||||
}
|
||||
|
||||
// RefreshToken exchanges a refresh token for a new access token.
|
||||
func (c *DeviceFlowClient) RefreshToken(ctx context.Context, refreshToken string) (*KimiTokenData, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", kimiClientID)
|
||||
data.Set("grant_type", "refresh_token")
|
||||
data.Set("refresh_token", refreshToken)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, kimiTokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to create refresh request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
for k, v := range c.commonHeaders() {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: refresh request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi refresh token: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to read refresh response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
return nil, fmt.Errorf("kimi: refresh token rejected (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("kimi: refresh failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn float64 `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
if err = json.Unmarshal(bodyBytes, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to parse refresh response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken == "" {
|
||||
return nil, fmt.Errorf("kimi: empty access token in refresh response")
|
||||
}
|
||||
|
||||
var expiresAt int64
|
||||
if tokenResp.ExpiresIn > 0 {
|
||||
expiresAt = time.Now().Unix() + int64(tokenResp.ExpiresIn)
|
||||
}
|
||||
|
||||
return &KimiTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
ExpiresAt: expiresAt,
|
||||
Scope: tokenResp.Scope,
|
||||
}, nil
|
||||
}
|
||||
131
internal/auth/kimi/token.go
Normal file
131
internal/auth/kimi/token.go
Normal file
@@ -0,0 +1,131 @@
|
||||
// Package kimi provides authentication and token management functionality
|
||||
// for Kimi (Moonshot AI) services. It handles OAuth2 device flow token storage,
|
||||
// serialization, and retrieval for maintaining authenticated sessions with the Kimi API.
|
||||
package kimi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
)
|
||||
|
||||
// KimiTokenStorage stores OAuth2 token information for Kimi API authentication.
|
||||
type KimiTokenStorage struct {
|
||||
// AccessToken is the OAuth2 access token used for authenticating API requests.
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is the OAuth2 refresh token used to obtain new access tokens.
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
// TokenType is the type of token, typically "Bearer".
|
||||
TokenType string `json:"token_type"`
|
||||
// Scope is the OAuth2 scope granted to the token.
|
||||
Scope string `json:"scope,omitempty"`
|
||||
// DeviceID is the OAuth device flow identifier used for Kimi requests.
|
||||
DeviceID string `json:"device_id,omitempty"`
|
||||
// Expired is the RFC3339 timestamp when the access token expires.
|
||||
Expired string `json:"expired,omitempty"`
|
||||
// Type indicates the authentication provider type, always "kimi" for this storage.
|
||||
Type string `json:"type"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *KimiTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// KimiTokenData holds the raw OAuth token response from Kimi.
|
||||
type KimiTokenData struct {
|
||||
// AccessToken is the OAuth2 access token.
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is the OAuth2 refresh token.
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
// TokenType is the type of token, typically "Bearer".
|
||||
TokenType string `json:"token_type"`
|
||||
// ExpiresAt is the Unix timestamp when the token expires.
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
// Scope is the OAuth2 scope granted to the token.
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// KimiAuthBundle bundles authentication data for storage.
|
||||
type KimiAuthBundle struct {
|
||||
// TokenData contains the OAuth token information.
|
||||
TokenData *KimiTokenData
|
||||
// DeviceID is the device identifier used during OAuth device flow.
|
||||
DeviceID string
|
||||
}
|
||||
|
||||
// DeviceCodeResponse represents Kimi's device code response.
|
||||
type DeviceCodeResponse struct {
|
||||
// DeviceCode is the device verification code.
|
||||
DeviceCode string `json:"device_code"`
|
||||
// UserCode is the code the user must enter at the verification URI.
|
||||
UserCode string `json:"user_code"`
|
||||
// VerificationURI is the URL where the user should enter the code.
|
||||
VerificationURI string `json:"verification_uri,omitempty"`
|
||||
// VerificationURIComplete is the URL with the code pre-filled.
|
||||
VerificationURIComplete string `json:"verification_uri_complete"`
|
||||
// ExpiresIn is the number of seconds until the device code expires.
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
// Interval is the minimum number of seconds to wait between polling requests.
|
||||
Interval int `json:"interval"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Kimi token storage to a JSON file.
|
||||
func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "kimi"
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
f, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
encoder := json.NewEncoder(f)
|
||||
encoder.SetIndent("", " ")
|
||||
if err = encoder.Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsExpired checks if the token has expired.
|
||||
func (ts *KimiTokenStorage) IsExpired() bool {
|
||||
if ts.Expired == "" {
|
||||
return false // No expiry set, assume valid
|
||||
}
|
||||
t, err := time.Parse(time.RFC3339, ts.Expired)
|
||||
if err != nil {
|
||||
return true // Has expiry string but can't parse
|
||||
}
|
||||
// Consider expired if within refresh threshold
|
||||
return time.Now().Add(time.Duration(refreshThresholdSeconds) * time.Second).After(t)
|
||||
}
|
||||
|
||||
// NeedsRefresh checks if the token should be refreshed.
|
||||
func (ts *KimiTokenStorage) NeedsRefresh() bool {
|
||||
if ts.RefreshToken == "" {
|
||||
return false // Can't refresh without refresh token
|
||||
}
|
||||
return ts.IsExpired()
|
||||
}
|
||||
681
internal/auth/kiro/aws.go
Normal file
681
internal/auth/kiro/aws.go
Normal file
@@ -0,0 +1,681 @@
|
||||
// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API.
|
||||
// It includes interfaces and implementations for token storage and authentication methods.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
|
||||
type PKCECodes struct {
|
||||
// CodeVerifier is the cryptographically random string used to correlate
|
||||
// the authorization request to the token request
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
// CodeChallenge is the SHA256 hash of the code verifier, base64url-encoded
|
||||
CodeChallenge string `json:"code_challenge"`
|
||||
}
|
||||
|
||||
// KiroTokenData holds OAuth token information from AWS CodeWhisperer (Kiro)
|
||||
type KiroTokenData struct {
|
||||
// AccessToken is the OAuth2 access token for API access
|
||||
AccessToken string `json:"accessToken"`
|
||||
// RefreshToken is used to obtain new access tokens
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
// ProfileArn is the AWS CodeWhisperer profile ARN
|
||||
ProfileArn string `json:"profileArn"`
|
||||
// ExpiresAt is the timestamp when the token expires
|
||||
ExpiresAt string `json:"expiresAt"`
|
||||
// AuthMethod indicates the authentication method used (e.g., "builder-id", "social", "idc")
|
||||
AuthMethod string `json:"authMethod"`
|
||||
// Provider indicates the OAuth provider (e.g., "AWS", "Google", "Enterprise")
|
||||
Provider string `json:"provider"`
|
||||
// ClientID is the OIDC client ID (needed for token refresh)
|
||||
ClientID string `json:"clientId,omitempty"`
|
||||
// ClientSecret is the OIDC client secret (needed for token refresh)
|
||||
ClientSecret string `json:"clientSecret,omitempty"`
|
||||
// ClientIDHash is the hash of client ID used to locate device registration file
|
||||
// (Enterprise Kiro IDE stores clientId/clientSecret in ~/.aws/sso/cache/{clientIdHash}.json)
|
||||
ClientIDHash string `json:"clientIdHash,omitempty"`
|
||||
// Email is the user's email address (used for file naming)
|
||||
Email string `json:"email,omitempty"`
|
||||
// StartURL is the IDC/Identity Center start URL (only for IDC auth method)
|
||||
StartURL string `json:"startUrl,omitempty"`
|
||||
// Region is the OIDC region for IDC login and token refresh
|
||||
Region string `json:"region,omitempty"`
|
||||
}
|
||||
|
||||
// KiroAuthBundle aggregates authentication data after OAuth flow completion
|
||||
type KiroAuthBundle struct {
|
||||
// TokenData contains the OAuth tokens from the authentication flow
|
||||
TokenData KiroTokenData `json:"token_data"`
|
||||
// LastRefresh is the timestamp of the last token refresh
|
||||
LastRefresh string `json:"last_refresh"`
|
||||
}
|
||||
|
||||
// KiroUsageInfo represents usage information from CodeWhisperer API
|
||||
type KiroUsageInfo struct {
|
||||
// SubscriptionTitle is the subscription plan name (e.g., "KIRO FREE")
|
||||
SubscriptionTitle string `json:"subscription_title"`
|
||||
// CurrentUsage is the current credit usage
|
||||
CurrentUsage float64 `json:"current_usage"`
|
||||
// UsageLimit is the maximum credit limit
|
||||
UsageLimit float64 `json:"usage_limit"`
|
||||
// NextReset is the timestamp of the next usage reset
|
||||
NextReset string `json:"next_reset"`
|
||||
}
|
||||
|
||||
// KiroModel represents a model available through the CodeWhisperer API
|
||||
type KiroModel struct {
|
||||
// ModelID is the unique identifier for the model
|
||||
ModelID string `json:"modelId"`
|
||||
// ModelName is the human-readable name
|
||||
ModelName string `json:"modelName"`
|
||||
// Description is the model description
|
||||
Description string `json:"description"`
|
||||
// RateMultiplier is the credit multiplier for this model
|
||||
RateMultiplier float64 `json:"rateMultiplier"`
|
||||
// RateUnit is the unit for rate calculation (e.g., "credit")
|
||||
RateUnit string `json:"rateUnit"`
|
||||
// MaxInputTokens is the maximum input token limit
|
||||
MaxInputTokens int `json:"maxInputTokens,omitempty"`
|
||||
}
|
||||
|
||||
// KiroIDETokenFile is the default path to Kiro IDE's token file
|
||||
const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json"
|
||||
|
||||
// Default retry configuration for file reading
|
||||
const (
|
||||
defaultTokenReadMaxAttempts = 10 // Maximum retry attempts
|
||||
defaultTokenReadBaseDelay = 50 * time.Millisecond // Base delay between retries
|
||||
)
|
||||
|
||||
// isTransientFileError checks if the error is a transient file access error
|
||||
// that may be resolved by retrying (e.g., file locked by another process on Windows).
|
||||
func isTransientFileError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for OS-level file access errors (Windows sharing violation, etc.)
|
||||
var pathErr *os.PathError
|
||||
if errors.As(err, &pathErr) {
|
||||
// Windows sharing violation (ERROR_SHARING_VIOLATION = 32)
|
||||
// Windows lock violation (ERROR_LOCK_VIOLATION = 33)
|
||||
errStr := pathErr.Err.Error()
|
||||
if strings.Contains(errStr, "being used by another process") ||
|
||||
strings.Contains(errStr, "sharing violation") ||
|
||||
strings.Contains(errStr, "lock violation") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check error message for common transient patterns
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
transientPatterns := []string{
|
||||
"being used by another process",
|
||||
"sharing violation",
|
||||
"lock violation",
|
||||
"access is denied",
|
||||
"unexpected end of json",
|
||||
"unexpected eof",
|
||||
}
|
||||
for _, pattern := range transientPatterns {
|
||||
if strings.Contains(errMsg, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// LoadKiroIDETokenWithRetry loads token data from Kiro IDE's token file with retry logic.
|
||||
// This handles transient file access errors (e.g., file locked by Kiro IDE during write).
|
||||
// maxAttempts: maximum number of retry attempts (default 10 if <= 0)
|
||||
// baseDelay: base delay between retries with exponential backoff (default 50ms if <= 0)
|
||||
func LoadKiroIDETokenWithRetry(maxAttempts int, baseDelay time.Duration) (*KiroTokenData, error) {
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = defaultTokenReadMaxAttempts
|
||||
}
|
||||
if baseDelay <= 0 {
|
||||
baseDelay = defaultTokenReadBaseDelay
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||
token, err := LoadKiroIDEToken()
|
||||
if err == nil {
|
||||
return token, nil
|
||||
}
|
||||
lastErr = err
|
||||
|
||||
// Only retry for transient errors
|
||||
if !isTransientFileError(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Exponential backoff: delay * 2^attempt, capped at 500ms
|
||||
delay := baseDelay * time.Duration(1<<uint(attempt))
|
||||
if delay > 500*time.Millisecond {
|
||||
delay = 500 * time.Millisecond
|
||||
}
|
||||
time.Sleep(delay)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to read token file after %d attempts: %w", maxAttempts, lastErr)
|
||||
}
|
||||
|
||||
// LoadKiroIDEToken loads token data from Kiro IDE's token file.
|
||||
// For Enterprise Kiro IDE (IDC auth), it also loads clientId and clientSecret
|
||||
// from the device registration file referenced by clientIdHash.
|
||||
func LoadKiroIDEToken() (*KiroTokenData, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
|
||||
tokenPath := filepath.Join(homeDir, KiroIDETokenFile)
|
||||
data, err := os.ReadFile(tokenPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read Kiro IDE token file (%s): %w", tokenPath, err)
|
||||
}
|
||||
|
||||
var token KiroTokenData
|
||||
if err := json.Unmarshal(data, &token); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse Kiro IDE token: %w", err)
|
||||
}
|
||||
|
||||
if token.AccessToken == "" {
|
||||
return nil, fmt.Errorf("access token is empty in Kiro IDE token file")
|
||||
}
|
||||
|
||||
// Normalize AuthMethod to lowercase (Kiro IDE uses "IdC" but we expect "idc")
|
||||
token.AuthMethod = strings.ToLower(token.AuthMethod)
|
||||
|
||||
// For Enterprise Kiro IDE (IDC auth), load clientId and clientSecret from device registration
|
||||
// The device registration file is located at ~/.aws/sso/cache/{clientIdHash}.json
|
||||
if token.ClientIDHash != "" && token.ClientID == "" {
|
||||
if err := loadDeviceRegistration(homeDir, token.ClientIDHash, &token); err != nil {
|
||||
// Log warning but don't fail - token might still work for some operations
|
||||
fmt.Printf("warning: failed to load device registration for clientIdHash %s: %v\n", token.ClientIDHash, err)
|
||||
}
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// loadDeviceRegistration loads clientId and clientSecret from the device registration file.
|
||||
// Enterprise Kiro IDE stores these in ~/.aws/sso/cache/{clientIdHash}.json
|
||||
func loadDeviceRegistration(homeDir, clientIDHash string, token *KiroTokenData) error {
|
||||
if clientIDHash == "" {
|
||||
return fmt.Errorf("clientIdHash is empty")
|
||||
}
|
||||
|
||||
// Sanitize clientIdHash to prevent path traversal
|
||||
if strings.Contains(clientIDHash, "/") || strings.Contains(clientIDHash, "\\") || strings.Contains(clientIDHash, "..") {
|
||||
return fmt.Errorf("invalid clientIdHash: contains path separator")
|
||||
}
|
||||
|
||||
deviceRegPath := filepath.Join(homeDir, ".aws", "sso", "cache", clientIDHash+".json")
|
||||
data, err := os.ReadFile(deviceRegPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read device registration file (%s): %w", deviceRegPath, err)
|
||||
}
|
||||
|
||||
// Device registration file structure
|
||||
var deviceReg struct {
|
||||
ClientID string `json:"clientId"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
ExpiresAt string `json:"expiresAt"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &deviceReg); err != nil {
|
||||
return fmt.Errorf("failed to parse device registration: %w", err)
|
||||
}
|
||||
|
||||
if deviceReg.ClientID == "" || deviceReg.ClientSecret == "" {
|
||||
return fmt.Errorf("device registration missing clientId or clientSecret")
|
||||
}
|
||||
|
||||
token.ClientID = deviceReg.ClientID
|
||||
token.ClientSecret = deviceReg.ClientSecret
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadKiroTokenFromPath loads token data from a custom path.
|
||||
// This supports multiple accounts by allowing different token files.
|
||||
// For Enterprise Kiro IDE (IDC auth), it also loads clientId and clientSecret
|
||||
// from the device registration file referenced by clientIdHash.
|
||||
func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
|
||||
// Expand ~ to home directory
|
||||
if len(tokenPath) > 0 && tokenPath[0] == '~' {
|
||||
tokenPath = filepath.Join(homeDir, tokenPath[1:])
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(tokenPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read token file (%s): %w", tokenPath, err)
|
||||
}
|
||||
|
||||
var token KiroTokenData
|
||||
if err := json.Unmarshal(data, &token); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token file: %w", err)
|
||||
}
|
||||
|
||||
if token.AccessToken == "" {
|
||||
return nil, fmt.Errorf("access token is empty in token file")
|
||||
}
|
||||
|
||||
// Normalize AuthMethod to lowercase (Kiro IDE uses "IdC" but we expect "idc")
|
||||
token.AuthMethod = strings.ToLower(token.AuthMethod)
|
||||
|
||||
// For Enterprise Kiro IDE (IDC auth), load clientId and clientSecret from device registration
|
||||
if token.ClientIDHash != "" && token.ClientID == "" {
|
||||
if err := loadDeviceRegistration(homeDir, token.ClientIDHash, &token); err != nil {
|
||||
// Log warning but don't fail - token might still work for some operations
|
||||
fmt.Printf("warning: failed to load device registration for clientIdHash %s: %v\n", token.ClientIDHash, err)
|
||||
}
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// ListKiroTokenFiles lists all Kiro token files in the cache directory.
|
||||
// This supports multiple accounts by finding all token files.
|
||||
func ListKiroTokenFiles() ([]string, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
|
||||
cacheDir := filepath.Join(homeDir, ".aws", "sso", "cache")
|
||||
|
||||
// Check if directory exists
|
||||
if _, err := os.Stat(cacheDir); os.IsNotExist(err) {
|
||||
return nil, nil // No token files
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(cacheDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read cache directory: %w", err)
|
||||
}
|
||||
|
||||
var tokenFiles []string
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
// Look for kiro token files only (avoid matching unrelated AWS SSO cache files)
|
||||
if strings.HasSuffix(name, ".json") && strings.HasPrefix(name, "kiro") {
|
||||
tokenFiles = append(tokenFiles, filepath.Join(cacheDir, name))
|
||||
}
|
||||
}
|
||||
|
||||
return tokenFiles, nil
|
||||
}
|
||||
|
||||
// LoadAllKiroTokens loads all Kiro tokens from the cache directory.
|
||||
// This supports multiple accounts.
|
||||
func LoadAllKiroTokens() ([]*KiroTokenData, error) {
|
||||
files, err := ListKiroTokenFiles()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tokens []*KiroTokenData
|
||||
for _, file := range files {
|
||||
token, err := LoadKiroTokenFromPath(file)
|
||||
if err != nil {
|
||||
// Skip invalid token files
|
||||
continue
|
||||
}
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// JWTClaims represents the claims we care about from a JWT token.
|
||||
// JWT tokens from Kiro/AWS contain user information in the payload.
|
||||
type JWTClaims struct {
|
||||
Email string `json:"email,omitempty"`
|
||||
Sub string `json:"sub,omitempty"`
|
||||
PreferredUser string `json:"preferred_username,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Iss string `json:"iss,omitempty"`
|
||||
}
|
||||
|
||||
// ExtractEmailFromJWT extracts the user's email from a JWT access token.
|
||||
// JWT tokens typically have format: header.payload.signature
|
||||
// The payload is base64url-encoded JSON containing user claims.
|
||||
func ExtractEmailFromJWT(accessToken string) string {
|
||||
if accessToken == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// JWT format: header.payload.signature
|
||||
parts := strings.Split(accessToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Decode the payload (second part)
|
||||
payload := parts[1]
|
||||
|
||||
// Add padding if needed (base64url requires padding)
|
||||
switch len(payload) % 4 {
|
||||
case 2:
|
||||
payload += "=="
|
||||
case 3:
|
||||
payload += "="
|
||||
}
|
||||
|
||||
decoded, err := base64.URLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
// Try RawURLEncoding (no padding)
|
||||
decoded, err = base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
var claims JWTClaims
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Return email if available
|
||||
if claims.Email != "" {
|
||||
return claims.Email
|
||||
}
|
||||
|
||||
// Fallback to preferred_username (some providers use this)
|
||||
if claims.PreferredUser != "" && strings.Contains(claims.PreferredUser, "@") {
|
||||
return claims.PreferredUser
|
||||
}
|
||||
|
||||
// Fallback to sub if it looks like an email
|
||||
if claims.Sub != "" && strings.Contains(claims.Sub, "@") {
|
||||
return claims.Sub
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// SanitizeEmailForFilename sanitizes an email address for use in a filename.
|
||||
// Replaces special characters with underscores and prevents path traversal attacks.
|
||||
// Also handles URL-encoded characters to prevent encoded path traversal attempts.
|
||||
func SanitizeEmailForFilename(email string) string {
|
||||
if email == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
result := email
|
||||
|
||||
// First, handle URL-encoded path traversal attempts (%2F, %2E, %5C, etc.)
|
||||
// This prevents encoded characters from bypassing the sanitization.
|
||||
// Note: We replace % last to catch any remaining encodings including double-encoding (%252F)
|
||||
result = strings.ReplaceAll(result, "%2F", "_") // /
|
||||
result = strings.ReplaceAll(result, "%2f", "_")
|
||||
result = strings.ReplaceAll(result, "%5C", "_") // \
|
||||
result = strings.ReplaceAll(result, "%5c", "_")
|
||||
result = strings.ReplaceAll(result, "%2E", "_") // .
|
||||
result = strings.ReplaceAll(result, "%2e", "_")
|
||||
result = strings.ReplaceAll(result, "%00", "_") // null byte
|
||||
result = strings.ReplaceAll(result, "%", "_") // Catch remaining % to prevent double-encoding attacks
|
||||
|
||||
// Replace characters that are problematic in filenames
|
||||
// Keep @ and . in middle but replace other special characters
|
||||
for _, char := range []string{"/", "\\", ":", "*", "?", "\"", "<", ">", "|", " ", "\x00"} {
|
||||
result = strings.ReplaceAll(result, char, "_")
|
||||
}
|
||||
|
||||
// Prevent path traversal: replace leading dots in each path component
|
||||
// This handles cases like "../../../etc/passwd" → "_.._.._.._etc_passwd"
|
||||
parts := strings.Split(result, "_")
|
||||
for i, part := range parts {
|
||||
for strings.HasPrefix(part, ".") {
|
||||
part = "_" + part[1:]
|
||||
}
|
||||
parts[i] = part
|
||||
}
|
||||
result = strings.Join(parts, "_")
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ExtractIDCIdentifier extracts a unique identifier from IDC startUrl.
|
||||
// Examples:
|
||||
// - "https://d-1234567890.awsapps.com/start" -> "d-1234567890"
|
||||
// - "https://my-company.awsapps.com/start" -> "my-company"
|
||||
// - "https://acme-corp.awsapps.com/start" -> "acme-corp"
|
||||
func ExtractIDCIdentifier(startURL string) string {
|
||||
if startURL == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Remove protocol prefix
|
||||
url := strings.TrimPrefix(startURL, "https://")
|
||||
url = strings.TrimPrefix(url, "http://")
|
||||
|
||||
// Extract subdomain (first part before the first dot)
|
||||
// Format: {identifier}.awsapps.com/start
|
||||
parts := strings.Split(url, ".")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
identifier := parts[0]
|
||||
// Sanitize for filename safety
|
||||
identifier = strings.ReplaceAll(identifier, "/", "_")
|
||||
identifier = strings.ReplaceAll(identifier, "\\", "_")
|
||||
identifier = strings.ReplaceAll(identifier, ":", "_")
|
||||
return identifier
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// GenerateTokenFileName generates a unique filename for token storage.
|
||||
// Priority: email > startUrl identifier (for IDC) > authMethod only
|
||||
// Email is unique, so no sequence suffix needed. Sequence is only added
|
||||
// when email is unavailable to prevent filename collisions.
|
||||
// Format: kiro-{authMethod}-{identifier}[-{seq}].json
|
||||
func GenerateTokenFileName(tokenData *KiroTokenData) string {
|
||||
authMethod := tokenData.AuthMethod
|
||||
if authMethod == "" {
|
||||
authMethod = "unknown"
|
||||
}
|
||||
|
||||
// Priority 1: Use email if available (no sequence needed, email is unique)
|
||||
if tokenData.Email != "" {
|
||||
// Sanitize email for filename (replace @ and . with -)
|
||||
sanitizedEmail := tokenData.Email
|
||||
sanitizedEmail = strings.ReplaceAll(sanitizedEmail, "@", "-")
|
||||
sanitizedEmail = strings.ReplaceAll(sanitizedEmail, ".", "-")
|
||||
return fmt.Sprintf("kiro-%s-%s.json", authMethod, sanitizedEmail)
|
||||
}
|
||||
|
||||
// Generate sequence only when email is unavailable
|
||||
seq := time.Now().UnixNano() % 100000
|
||||
|
||||
// Priority 2: For IDC, use startUrl identifier with sequence
|
||||
if authMethod == "idc" && tokenData.StartURL != "" {
|
||||
identifier := ExtractIDCIdentifier(tokenData.StartURL)
|
||||
if identifier != "" {
|
||||
return fmt.Sprintf("kiro-%s-%s-%05d.json", authMethod, identifier, seq)
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 3: Fallback to authMethod only with sequence
|
||||
return fmt.Sprintf("kiro-%s-%05d.json", authMethod, seq)
|
||||
}
|
||||
|
||||
// DefaultKiroRegion is the fallback region when none is specified.
|
||||
const DefaultKiroRegion = "us-east-1"
|
||||
|
||||
// GetCodeWhispererLegacyEndpoint returns the legacy CodeWhisperer JSON-RPC endpoint.
|
||||
// This endpoint supports JSON-RPC style requests with x-amz-target headers.
|
||||
// The Q endpoint (q.{region}.amazonaws.com) does NOT support JSON-RPC style.
|
||||
func GetCodeWhispererLegacyEndpoint(region string) string {
|
||||
if region == "" {
|
||||
region = DefaultKiroRegion
|
||||
}
|
||||
return "https://codewhisperer." + region + ".amazonaws.com"
|
||||
}
|
||||
|
||||
// ProfileARN represents a parsed AWS CodeWhisperer profile ARN.
|
||||
// ARN format: arn:partition:service:region:account-id:resource-type/resource-id
|
||||
// Example: arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL
|
||||
type ProfileARN struct {
|
||||
// Raw is the original ARN string
|
||||
Raw string
|
||||
// Partition is the AWS partition (aws)
|
||||
Partition string
|
||||
// Service is the AWS service name (codewhisperer)
|
||||
Service string
|
||||
// Region is the AWS region (us-east-1, ap-southeast-1, etc.)
|
||||
Region string
|
||||
// AccountID is the AWS account ID
|
||||
AccountID string
|
||||
// ResourceType is the resource type (profile)
|
||||
ResourceType string
|
||||
// ResourceID is the resource identifier (e.g., ABCDEFGHIJKL)
|
||||
ResourceID string
|
||||
}
|
||||
|
||||
// ParseProfileARN parses an AWS ARN string into a ProfileARN struct.
|
||||
// Returns nil if the ARN is empty, invalid, or not a codewhisperer ARN.
|
||||
func ParseProfileARN(arn string) *ProfileARN {
|
||||
if arn == "" {
|
||||
return nil
|
||||
}
|
||||
// ARN format: arn:partition:service:region:account-id:resource
|
||||
// Minimum 6 parts separated by ":"
|
||||
parts := strings.Split(arn, ":")
|
||||
if len(parts) < 6 {
|
||||
log.Warnf("invalid ARN format: %s", arn)
|
||||
return nil
|
||||
}
|
||||
// Validate ARN prefix
|
||||
if parts[0] != "arn" {
|
||||
return nil
|
||||
}
|
||||
// Validate partition
|
||||
partition := parts[1]
|
||||
if partition == "" {
|
||||
return nil
|
||||
}
|
||||
// Validate service is codewhisperer
|
||||
service := parts[2]
|
||||
if service != "codewhisperer" {
|
||||
return nil
|
||||
}
|
||||
// Validate region format (must contain "-")
|
||||
region := parts[3]
|
||||
if region == "" || !strings.Contains(region, "-") {
|
||||
return nil
|
||||
}
|
||||
// Account ID
|
||||
accountID := parts[4]
|
||||
|
||||
// Parse resource (format: resource-type/resource-id)
|
||||
// Join remaining parts in case resource contains ":"
|
||||
resource := strings.Join(parts[5:], ":")
|
||||
resourceType := ""
|
||||
resourceID := ""
|
||||
if idx := strings.Index(resource, "/"); idx > 0 {
|
||||
resourceType = resource[:idx]
|
||||
resourceID = resource[idx+1:]
|
||||
} else {
|
||||
resourceType = resource
|
||||
}
|
||||
|
||||
return &ProfileARN{
|
||||
Raw: arn,
|
||||
Partition: partition,
|
||||
Service: service,
|
||||
Region: region,
|
||||
AccountID: accountID,
|
||||
ResourceType: resourceType,
|
||||
ResourceID: resourceID,
|
||||
}
|
||||
}
|
||||
|
||||
// GetKiroAPIEndpoint returns the Q API endpoint for the specified region.
|
||||
// If region is empty, defaults to us-east-1.
|
||||
func GetKiroAPIEndpoint(region string) string {
|
||||
if region == "" {
|
||||
region = DefaultKiroRegion
|
||||
}
|
||||
return "https://q." + region + ".amazonaws.com"
|
||||
}
|
||||
|
||||
// GetKiroAPIEndpointFromProfileArn extracts region from profileArn and returns the endpoint.
|
||||
// Returns default us-east-1 endpoint if region cannot be extracted.
|
||||
func GetKiroAPIEndpointFromProfileArn(profileArn string) string {
|
||||
region := ExtractRegionFromProfileArn(profileArn)
|
||||
return GetKiroAPIEndpoint(region)
|
||||
}
|
||||
|
||||
// ExtractRegionFromProfileArn extracts the AWS region from a ProfileARN string.
|
||||
// Returns empty string if ARN is invalid or region cannot be extracted.
|
||||
func ExtractRegionFromProfileArn(profileArn string) string {
|
||||
parsed := ParseProfileARN(profileArn)
|
||||
if parsed == nil {
|
||||
return ""
|
||||
}
|
||||
return parsed.Region
|
||||
}
|
||||
|
||||
// ExtractRegionFromMetadata extracts API region from auth metadata.
|
||||
// Priority: api_region > profile_arn > DefaultKiroRegion
|
||||
func ExtractRegionFromMetadata(metadata map[string]interface{}) string {
|
||||
if metadata == nil {
|
||||
return DefaultKiroRegion
|
||||
}
|
||||
|
||||
// Priority 1: Explicit api_region override
|
||||
if r, ok := metadata["api_region"].(string); ok && r != "" {
|
||||
return r
|
||||
}
|
||||
|
||||
// Priority 2: Extract from ProfileARN
|
||||
if profileArn, ok := metadata["profile_arn"].(string); ok && profileArn != "" {
|
||||
if region := ExtractRegionFromProfileArn(profileArn); region != "" {
|
||||
return region
|
||||
}
|
||||
}
|
||||
|
||||
return DefaultKiroRegion
|
||||
}
|
||||
|
||||
func buildURL(endpoint, path string, queryParams map[string]string) string {
|
||||
fullURL := fmt.Sprintf("%s/%s", endpoint, path)
|
||||
if len(queryParams) > 0 {
|
||||
values := url.Values{}
|
||||
for key, value := range queryParams {
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
values.Set(key, value)
|
||||
}
|
||||
if encoded := values.Encode(); encoded != "" {
|
||||
fullURL = fullURL + "?" + encoded
|
||||
}
|
||||
}
|
||||
return fullURL
|
||||
}
|
||||
326
internal/auth/kiro/aws_auth.go
Normal file
326
internal/auth/kiro/aws_auth.go
Normal file
@@ -0,0 +1,326 @@
|
||||
// Package kiro provides OAuth2 authentication functionality for AWS CodeWhisperer (Kiro) API.
|
||||
// This package implements token loading, refresh, and API communication with CodeWhisperer.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
pathGetUsageLimits = "getUsageLimits"
|
||||
pathListAvailableModels = "ListAvailableModels"
|
||||
)
|
||||
|
||||
// KiroAuth handles AWS CodeWhisperer authentication and API communication.
|
||||
// It provides methods for loading tokens, refreshing expired tokens,
|
||||
// and communicating with the CodeWhisperer API.
|
||||
type KiroAuth struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewKiroAuth creates a new Kiro authentication service.
|
||||
// It initializes the HTTP client with proxy settings from the configuration.
|
||||
//
|
||||
// Parameters:
|
||||
// - cfg: The application configuration containing proxy settings
|
||||
//
|
||||
// Returns:
|
||||
// - *KiroAuth: A new Kiro authentication service instance
|
||||
func NewKiroAuth(cfg *config.Config) *KiroAuth {
|
||||
return &KiroAuth{
|
||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 120 * time.Second}),
|
||||
}
|
||||
}
|
||||
|
||||
// LoadTokenFromFile loads token data from a file path.
|
||||
// This method reads and parses the token file, expanding ~ to the home directory.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenFile: Path to the token file (supports ~ expansion)
|
||||
//
|
||||
// Returns:
|
||||
// - *KiroTokenData: The parsed token data
|
||||
// - error: An error if file reading or parsing fails
|
||||
func (k *KiroAuth) LoadTokenFromFile(tokenFile string) (*KiroTokenData, error) {
|
||||
// Expand ~ to home directory
|
||||
if strings.HasPrefix(tokenFile, "~") {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
tokenFile = filepath.Join(home, tokenFile[1:])
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(tokenFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read token file: %w", err)
|
||||
}
|
||||
|
||||
var tokenData KiroTokenData
|
||||
if err := json.Unmarshal(data, &tokenData); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token file: %w", err)
|
||||
}
|
||||
|
||||
return &tokenData, nil
|
||||
}
|
||||
|
||||
// IsTokenExpired checks if the token has expired.
|
||||
// This method parses the expiration timestamp and compares it with the current time.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenData: The token data to check
|
||||
//
|
||||
// Returns:
|
||||
// - bool: True if the token has expired, false otherwise
|
||||
func (k *KiroAuth) IsTokenExpired(tokenData *KiroTokenData) bool {
|
||||
if tokenData.ExpiresAt == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
// Try alternate format
|
||||
expiresAt, err = time.Parse("2006-01-02T15:04:05.000Z", tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return time.Now().After(expiresAt)
|
||||
}
|
||||
|
||||
// makeRequest sends a REST-style GET request to the CodeWhisperer API.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the request
|
||||
// - path: The API path (e.g., "getUsageLimits")
|
||||
// - tokenData: The token data containing access token, refresh token, and profile ARN
|
||||
// - queryParams: Query parameters to add to the URL
|
||||
//
|
||||
// Returns:
|
||||
// - []byte: The response body
|
||||
// - error: An error if the request fails
|
||||
func (k *KiroAuth) makeRequest(ctx context.Context, path string, tokenData *KiroTokenData, queryParams map[string]string) ([]byte, error) {
|
||||
// Get endpoint from profileArn (defaults to us-east-1 if empty)
|
||||
profileArn := queryParams["profileArn"]
|
||||
endpoint := GetKiroAPIEndpointFromProfileArn(profileArn)
|
||||
url := buildURL(endpoint, path, queryParams)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
accountKey := GetAccountKey(tokenData.ClientID, tokenData.RefreshToken)
|
||||
setRuntimeHeaders(req, tokenData.AccessToken, accountKey)
|
||||
|
||||
resp, err := k.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("failed to close response body: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// GetUsageLimits retrieves usage information from the CodeWhisperer API.
|
||||
// This method fetches the current usage statistics and subscription information.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the request
|
||||
// - tokenData: The token data containing access token and profile ARN
|
||||
//
|
||||
// Returns:
|
||||
// - *KiroUsageInfo: The usage information
|
||||
// - error: An error if the request fails
|
||||
func (k *KiroAuth) GetUsageLimits(ctx context.Context, tokenData *KiroTokenData) (*KiroUsageInfo, error) {
|
||||
queryParams := map[string]string{
|
||||
"origin": "AI_EDITOR",
|
||||
"profileArn": tokenData.ProfileArn,
|
||||
"resourceType": "AGENTIC_REQUEST",
|
||||
}
|
||||
|
||||
body, err := k.makeRequest(ctx, pathGetUsageLimits, tokenData, queryParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result struct {
|
||||
SubscriptionInfo struct {
|
||||
SubscriptionTitle string `json:"subscriptionTitle"`
|
||||
} `json:"subscriptionInfo"`
|
||||
UsageBreakdownList []struct {
|
||||
CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
|
||||
UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
|
||||
} `json:"usageBreakdownList"`
|
||||
NextDateReset float64 `json:"nextDateReset"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse usage response: %w", err)
|
||||
}
|
||||
|
||||
usage := &KiroUsageInfo{
|
||||
SubscriptionTitle: result.SubscriptionInfo.SubscriptionTitle,
|
||||
NextReset: fmt.Sprintf("%v", result.NextDateReset),
|
||||
}
|
||||
|
||||
if len(result.UsageBreakdownList) > 0 {
|
||||
usage.CurrentUsage = result.UsageBreakdownList[0].CurrentUsageWithPrecision
|
||||
usage.UsageLimit = result.UsageBreakdownList[0].UsageLimitWithPrecision
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// ListAvailableModels retrieves available models from the CodeWhisperer API.
|
||||
// This method fetches the list of AI models available for the authenticated user.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the request
|
||||
// - tokenData: The token data containing access token and profile ARN
|
||||
//
|
||||
// Returns:
|
||||
// - []*KiroModel: The list of available models
|
||||
// - error: An error if the request fails
|
||||
func (k *KiroAuth) ListAvailableModels(ctx context.Context, tokenData *KiroTokenData) ([]*KiroModel, error) {
|
||||
queryParams := map[string]string{
|
||||
"origin": "AI_EDITOR",
|
||||
"profileArn": tokenData.ProfileArn,
|
||||
}
|
||||
|
||||
body, err := k.makeRequest(ctx, pathListAvailableModels, tokenData, queryParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Models []struct {
|
||||
ModelID string `json:"modelId"`
|
||||
ModelName string `json:"modelName"`
|
||||
Description string `json:"description"`
|
||||
RateMultiplier float64 `json:"rateMultiplier"`
|
||||
RateUnit string `json:"rateUnit"`
|
||||
TokenLimits *struct {
|
||||
MaxInputTokens int `json:"maxInputTokens"`
|
||||
} `json:"tokenLimits"`
|
||||
} `json:"models"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse models response: %w", err)
|
||||
}
|
||||
|
||||
models := make([]*KiroModel, 0, len(result.Models))
|
||||
for _, m := range result.Models {
|
||||
maxInputTokens := 0
|
||||
if m.TokenLimits != nil {
|
||||
maxInputTokens = m.TokenLimits.MaxInputTokens
|
||||
}
|
||||
models = append(models, &KiroModel{
|
||||
ModelID: m.ModelID,
|
||||
ModelName: m.ModelName,
|
||||
Description: m.Description,
|
||||
RateMultiplier: m.RateMultiplier,
|
||||
RateUnit: m.RateUnit,
|
||||
MaxInputTokens: maxInputTokens,
|
||||
})
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
// CreateTokenStorage creates a new KiroTokenStorage from token data.
|
||||
// This method converts the token data into a storage structure suitable for persistence.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenData: The token data to convert
|
||||
//
|
||||
// Returns:
|
||||
// - *KiroTokenStorage: A new token storage instance
|
||||
func (k *KiroAuth) CreateTokenStorage(tokenData *KiroTokenData) *KiroTokenStorage {
|
||||
return &KiroTokenStorage{
|
||||
AccessToken: tokenData.AccessToken,
|
||||
RefreshToken: tokenData.RefreshToken,
|
||||
ProfileArn: tokenData.ProfileArn,
|
||||
ExpiresAt: tokenData.ExpiresAt,
|
||||
AuthMethod: tokenData.AuthMethod,
|
||||
Provider: tokenData.Provider,
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
ClientID: tokenData.ClientID,
|
||||
ClientSecret: tokenData.ClientSecret,
|
||||
Region: tokenData.Region,
|
||||
StartURL: tokenData.StartURL,
|
||||
Email: tokenData.Email,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateToken checks if the token is valid by making a test API call.
|
||||
// This method verifies the token by attempting to fetch usage limits.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the request
|
||||
// - tokenData: The token data to validate
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if the token is invalid
|
||||
func (k *KiroAuth) ValidateToken(ctx context.Context, tokenData *KiroTokenData) error {
|
||||
_, err := k.GetUsageLimits(ctx, tokenData)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateTokenStorage updates an existing token storage with new token data.
|
||||
// This method refreshes the token storage with newly obtained access and refresh tokens.
|
||||
//
|
||||
// Parameters:
|
||||
// - storage: The existing token storage to update
|
||||
// - tokenData: The new token data to apply
|
||||
func (k *KiroAuth) UpdateTokenStorage(storage *KiroTokenStorage, tokenData *KiroTokenData) {
|
||||
storage.AccessToken = tokenData.AccessToken
|
||||
storage.RefreshToken = tokenData.RefreshToken
|
||||
storage.ProfileArn = tokenData.ProfileArn
|
||||
storage.ExpiresAt = tokenData.ExpiresAt
|
||||
storage.AuthMethod = tokenData.AuthMethod
|
||||
storage.Provider = tokenData.Provider
|
||||
storage.LastRefresh = time.Now().Format(time.RFC3339)
|
||||
if tokenData.ClientID != "" {
|
||||
storage.ClientID = tokenData.ClientID
|
||||
}
|
||||
if tokenData.ClientSecret != "" {
|
||||
storage.ClientSecret = tokenData.ClientSecret
|
||||
}
|
||||
if tokenData.Region != "" {
|
||||
storage.Region = tokenData.Region
|
||||
}
|
||||
if tokenData.StartURL != "" {
|
||||
storage.StartURL = tokenData.StartURL
|
||||
}
|
||||
if tokenData.Email != "" {
|
||||
storage.Email = tokenData.Email
|
||||
}
|
||||
}
|
||||
751
internal/auth/kiro/aws_test.go
Normal file
751
internal/auth/kiro/aws_test.go
Normal file
@@ -0,0 +1,751 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractEmailFromJWT(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty token",
|
||||
token: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Invalid token format",
|
||||
token: "not.a.valid.jwt",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Invalid token - not base64",
|
||||
token: "xxx.yyy.zzz",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Valid JWT with email",
|
||||
token: createTestJWT(map[string]any{"email": "test@example.com", "sub": "user123"}),
|
||||
expected: "test@example.com",
|
||||
},
|
||||
{
|
||||
name: "JWT without email but with preferred_username",
|
||||
token: createTestJWT(map[string]any{"preferred_username": "user@domain.com", "sub": "user123"}),
|
||||
expected: "user@domain.com",
|
||||
},
|
||||
{
|
||||
name: "JWT with email-like sub",
|
||||
token: createTestJWT(map[string]any{"sub": "another@test.com"}),
|
||||
expected: "another@test.com",
|
||||
},
|
||||
{
|
||||
name: "JWT without any email fields",
|
||||
token: createTestJWT(map[string]any{"sub": "user123", "name": "Test User"}),
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ExtractEmailFromJWT(tt.token)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ExtractEmailFromJWT() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeEmailForFilename(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty email",
|
||||
email: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Simple email",
|
||||
email: "user@example.com",
|
||||
expected: "user@example.com",
|
||||
},
|
||||
{
|
||||
name: "Email with space",
|
||||
email: "user name@example.com",
|
||||
expected: "user_name@example.com",
|
||||
},
|
||||
{
|
||||
name: "Email with special chars",
|
||||
email: "user:name@example.com",
|
||||
expected: "user_name@example.com",
|
||||
},
|
||||
{
|
||||
name: "Email with multiple special chars",
|
||||
email: "user/name:test@example.com",
|
||||
expected: "user_name_test@example.com",
|
||||
},
|
||||
{
|
||||
name: "Path traversal attempt",
|
||||
email: "../../../etc/passwd",
|
||||
expected: "_.__.__._etc_passwd",
|
||||
},
|
||||
{
|
||||
name: "Path traversal with backslash",
|
||||
email: `..\..\..\..\windows\system32`,
|
||||
expected: "_.__.__.__._windows_system32",
|
||||
},
|
||||
{
|
||||
name: "Null byte injection attempt",
|
||||
email: "user\x00@evil.com",
|
||||
expected: "user_@evil.com",
|
||||
},
|
||||
// URL-encoded path traversal tests
|
||||
{
|
||||
name: "URL-encoded slash",
|
||||
email: "user%2Fpath@example.com",
|
||||
expected: "user_path@example.com",
|
||||
},
|
||||
{
|
||||
name: "URL-encoded backslash",
|
||||
email: "user%5Cpath@example.com",
|
||||
expected: "user_path@example.com",
|
||||
},
|
||||
{
|
||||
name: "URL-encoded dot",
|
||||
email: "%2E%2E%2Fetc%2Fpasswd",
|
||||
expected: "___etc_passwd",
|
||||
},
|
||||
{
|
||||
name: "URL-encoded null",
|
||||
email: "user%00@evil.com",
|
||||
expected: "user_@evil.com",
|
||||
},
|
||||
{
|
||||
name: "Double URL-encoding attack",
|
||||
email: "%252F%252E%252E",
|
||||
expected: "_252F_252E_252E", // % replaced with _, remaining chars preserved (safe)
|
||||
},
|
||||
{
|
||||
name: "Mixed case URL-encoding",
|
||||
email: "%2f%2F%5c%5C",
|
||||
expected: "____",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeEmailForFilename(tt.email)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeEmailForFilename() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// createTestJWT creates a test JWT token with the given claims
|
||||
func createTestJWT(claims map[string]any) string {
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`))
|
||||
|
||||
payloadBytes, _ := json.Marshal(claims)
|
||||
payload := base64.RawURLEncoding.EncodeToString(payloadBytes)
|
||||
|
||||
signature := base64.RawURLEncoding.EncodeToString([]byte("fake-signature"))
|
||||
|
||||
return header + "." + payload + "." + signature
|
||||
}
|
||||
|
||||
func TestExtractIDCIdentifier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
startURL string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty URL",
|
||||
startURL: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Standard IDC URL with d- prefix",
|
||||
startURL: "https://d-1234567890.awsapps.com/start",
|
||||
expected: "d-1234567890",
|
||||
},
|
||||
{
|
||||
name: "IDC URL with company name",
|
||||
startURL: "https://my-company.awsapps.com/start",
|
||||
expected: "my-company",
|
||||
},
|
||||
{
|
||||
name: "IDC URL with simple name",
|
||||
startURL: "https://acme-corp.awsapps.com/start",
|
||||
expected: "acme-corp",
|
||||
},
|
||||
{
|
||||
name: "IDC URL without https",
|
||||
startURL: "http://d-9876543210.awsapps.com/start",
|
||||
expected: "d-9876543210",
|
||||
},
|
||||
{
|
||||
name: "IDC URL with subdomain only",
|
||||
startURL: "https://test.awsapps.com/start",
|
||||
expected: "test",
|
||||
},
|
||||
{
|
||||
name: "Builder ID URL",
|
||||
startURL: "https://view.awsapps.com/start",
|
||||
expected: "view",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ExtractIDCIdentifier(tt.startURL)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ExtractIDCIdentifier() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateTokenFileName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenData *KiroTokenData
|
||||
exact string // exact match (for cases with email)
|
||||
prefix string // prefix match (for cases without email, where sequence is appended)
|
||||
}{
|
||||
{
|
||||
name: "IDC with email",
|
||||
tokenData: &KiroTokenData{
|
||||
AuthMethod: "idc",
|
||||
Email: "user@example.com",
|
||||
StartURL: "https://d-1234567890.awsapps.com/start",
|
||||
},
|
||||
exact: "kiro-idc-user-example-com.json",
|
||||
},
|
||||
{
|
||||
name: "IDC without email but with startUrl",
|
||||
tokenData: &KiroTokenData{
|
||||
AuthMethod: "idc",
|
||||
Email: "",
|
||||
StartURL: "https://d-1234567890.awsapps.com/start",
|
||||
},
|
||||
prefix: "kiro-idc-d-1234567890-",
|
||||
},
|
||||
{
|
||||
name: "IDC with company name in startUrl",
|
||||
tokenData: &KiroTokenData{
|
||||
AuthMethod: "idc",
|
||||
Email: "",
|
||||
StartURL: "https://my-company.awsapps.com/start",
|
||||
},
|
||||
prefix: "kiro-idc-my-company-",
|
||||
},
|
||||
{
|
||||
name: "IDC without email and without startUrl",
|
||||
tokenData: &KiroTokenData{
|
||||
AuthMethod: "idc",
|
||||
Email: "",
|
||||
StartURL: "",
|
||||
},
|
||||
prefix: "kiro-idc-",
|
||||
},
|
||||
{
|
||||
name: "Builder ID with email",
|
||||
tokenData: &KiroTokenData{
|
||||
AuthMethod: "builder-id",
|
||||
Email: "user@gmail.com",
|
||||
StartURL: "https://view.awsapps.com/start",
|
||||
},
|
||||
exact: "kiro-builder-id-user-gmail-com.json",
|
||||
},
|
||||
{
|
||||
name: "Builder ID without email",
|
||||
tokenData: &KiroTokenData{
|
||||
AuthMethod: "builder-id",
|
||||
Email: "",
|
||||
StartURL: "https://view.awsapps.com/start",
|
||||
},
|
||||
prefix: "kiro-builder-id-",
|
||||
},
|
||||
{
|
||||
name: "Social auth with email",
|
||||
tokenData: &KiroTokenData{
|
||||
AuthMethod: "google",
|
||||
Email: "user@gmail.com",
|
||||
},
|
||||
exact: "kiro-google-user-gmail-com.json",
|
||||
},
|
||||
{
|
||||
name: "Empty auth method",
|
||||
tokenData: &KiroTokenData{
|
||||
AuthMethod: "",
|
||||
Email: "",
|
||||
},
|
||||
prefix: "kiro-unknown-",
|
||||
},
|
||||
{
|
||||
name: "Email with special characters",
|
||||
tokenData: &KiroTokenData{
|
||||
AuthMethod: "idc",
|
||||
Email: "user.name+tag@sub.example.com",
|
||||
StartURL: "https://d-1234567890.awsapps.com/start",
|
||||
},
|
||||
exact: "kiro-idc-user-name+tag-sub-example-com.json",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GenerateTokenFileName(tt.tokenData)
|
||||
if tt.exact != "" {
|
||||
if result != tt.exact {
|
||||
t.Errorf("GenerateTokenFileName() = %q, want %q", result, tt.exact)
|
||||
}
|
||||
} else if tt.prefix != "" {
|
||||
if !strings.HasPrefix(result, tt.prefix) || !strings.HasSuffix(result, ".json") {
|
||||
t.Errorf("GenerateTokenFileName() = %q, want prefix %q with .json suffix", result, tt.prefix)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseProfileARN(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
arn string
|
||||
expected *ProfileARN
|
||||
}{
|
||||
{
|
||||
name: "Empty ARN",
|
||||
arn: "",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid format - too few parts",
|
||||
arn: "arn:aws:codewhisperer",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid prefix - not arn",
|
||||
arn: "notarn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid service - not codewhisperer",
|
||||
arn: "arn:aws:s3:us-east-1:123456789012:bucket/mybucket",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid region - no hyphen",
|
||||
arn: "arn:aws:codewhisperer:useast1:123456789012:profile/ABC",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Empty partition",
|
||||
arn: "arn::codewhisperer:us-east-1:123456789012:profile/ABC",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Empty region",
|
||||
arn: "arn:aws:codewhisperer::123456789012:profile/ABC",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - us-east-1",
|
||||
arn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL",
|
||||
expected: &ProfileARN{
|
||||
Raw: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL",
|
||||
Partition: "aws",
|
||||
Service: "codewhisperer",
|
||||
Region: "us-east-1",
|
||||
AccountID: "123456789012",
|
||||
ResourceType: "profile",
|
||||
ResourceID: "ABCDEFGHIJKL",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - ap-southeast-1",
|
||||
arn: "arn:aws:codewhisperer:ap-southeast-1:987654321098:profile/ZYXWVUTSRQ",
|
||||
expected: &ProfileARN{
|
||||
Raw: "arn:aws:codewhisperer:ap-southeast-1:987654321098:profile/ZYXWVUTSRQ",
|
||||
Partition: "aws",
|
||||
Service: "codewhisperer",
|
||||
Region: "ap-southeast-1",
|
||||
AccountID: "987654321098",
|
||||
ResourceType: "profile",
|
||||
ResourceID: "ZYXWVUTSRQ",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - eu-west-1",
|
||||
arn: "arn:aws:codewhisperer:eu-west-1:111222333444:profile/PROFILE123",
|
||||
expected: &ProfileARN{
|
||||
Raw: "arn:aws:codewhisperer:eu-west-1:111222333444:profile/PROFILE123",
|
||||
Partition: "aws",
|
||||
Service: "codewhisperer",
|
||||
Region: "eu-west-1",
|
||||
AccountID: "111222333444",
|
||||
ResourceType: "profile",
|
||||
ResourceID: "PROFILE123",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - aws-cn partition",
|
||||
arn: "arn:aws-cn:codewhisperer:cn-north-1:123456789012:profile/CHINAID",
|
||||
expected: &ProfileARN{
|
||||
Raw: "arn:aws-cn:codewhisperer:cn-north-1:123456789012:profile/CHINAID",
|
||||
Partition: "aws-cn",
|
||||
Service: "codewhisperer",
|
||||
Region: "cn-north-1",
|
||||
AccountID: "123456789012",
|
||||
ResourceType: "profile",
|
||||
ResourceID: "CHINAID",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - resource without slash",
|
||||
arn: "arn:aws:codewhisperer:us-west-2:123456789012:profile",
|
||||
expected: &ProfileARN{
|
||||
Raw: "arn:aws:codewhisperer:us-west-2:123456789012:profile",
|
||||
Partition: "aws",
|
||||
Service: "codewhisperer",
|
||||
Region: "us-west-2",
|
||||
AccountID: "123456789012",
|
||||
ResourceType: "profile",
|
||||
ResourceID: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - resource with colon",
|
||||
arn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC:extra",
|
||||
expected: &ProfileARN{
|
||||
Raw: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC:extra",
|
||||
Partition: "aws",
|
||||
Service: "codewhisperer",
|
||||
Region: "us-east-1",
|
||||
AccountID: "123456789012",
|
||||
ResourceType: "profile",
|
||||
ResourceID: "ABC:extra",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ParseProfileARN(tt.arn)
|
||||
if tt.expected == nil {
|
||||
if result != nil {
|
||||
t.Errorf("ParseProfileARN(%q) = %+v, want nil", tt.arn, result)
|
||||
}
|
||||
return
|
||||
}
|
||||
if result == nil {
|
||||
t.Errorf("ParseProfileARN(%q) = nil, want %+v", tt.arn, tt.expected)
|
||||
return
|
||||
}
|
||||
if result.Raw != tt.expected.Raw {
|
||||
t.Errorf("Raw = %q, want %q", result.Raw, tt.expected.Raw)
|
||||
}
|
||||
if result.Partition != tt.expected.Partition {
|
||||
t.Errorf("Partition = %q, want %q", result.Partition, tt.expected.Partition)
|
||||
}
|
||||
if result.Service != tt.expected.Service {
|
||||
t.Errorf("Service = %q, want %q", result.Service, tt.expected.Service)
|
||||
}
|
||||
if result.Region != tt.expected.Region {
|
||||
t.Errorf("Region = %q, want %q", result.Region, tt.expected.Region)
|
||||
}
|
||||
if result.AccountID != tt.expected.AccountID {
|
||||
t.Errorf("AccountID = %q, want %q", result.AccountID, tt.expected.AccountID)
|
||||
}
|
||||
if result.ResourceType != tt.expected.ResourceType {
|
||||
t.Errorf("ResourceType = %q, want %q", result.ResourceType, tt.expected.ResourceType)
|
||||
}
|
||||
if result.ResourceID != tt.expected.ResourceID {
|
||||
t.Errorf("ResourceID = %q, want %q", result.ResourceID, tt.expected.ResourceID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRegionFromProfileArn(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileArn string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty ARN",
|
||||
profileArn: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Invalid ARN",
|
||||
profileArn: "invalid-arn",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - us-east-1",
|
||||
profileArn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
|
||||
expected: "us-east-1",
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - ap-southeast-1",
|
||||
profileArn: "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
|
||||
expected: "ap-southeast-1",
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - eu-central-1",
|
||||
profileArn: "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC",
|
||||
expected: "eu-central-1",
|
||||
},
|
||||
{
|
||||
name: "Non-codewhisperer ARN",
|
||||
profileArn: "arn:aws:s3:us-east-1:123456789012:bucket/mybucket",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ExtractRegionFromProfileArn(tt.profileArn)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ExtractRegionFromProfileArn(%q) = %q, want %q", tt.profileArn, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetKiroAPIEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
region string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty region - defaults to us-east-1",
|
||||
region: "",
|
||||
expected: "https://q.us-east-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "us-east-1",
|
||||
region: "us-east-1",
|
||||
expected: "https://q.us-east-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "us-west-2",
|
||||
region: "us-west-2",
|
||||
expected: "https://q.us-west-2.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "ap-southeast-1",
|
||||
region: "ap-southeast-1",
|
||||
expected: "https://q.ap-southeast-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "eu-west-1",
|
||||
region: "eu-west-1",
|
||||
expected: "https://q.eu-west-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "cn-north-1",
|
||||
region: "cn-north-1",
|
||||
expected: "https://q.cn-north-1.amazonaws.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetKiroAPIEndpoint(tt.region)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetKiroAPIEndpoint(%q) = %q, want %q", tt.region, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetKiroAPIEndpointFromProfileArn(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileArn string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty ARN - defaults to us-east-1",
|
||||
profileArn: "",
|
||||
expected: "https://q.us-east-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "Invalid ARN - defaults to us-east-1",
|
||||
profileArn: "invalid-arn",
|
||||
expected: "https://q.us-east-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - us-east-1",
|
||||
profileArn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
|
||||
expected: "https://q.us-east-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - ap-southeast-1",
|
||||
profileArn: "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
|
||||
expected: "https://q.ap-southeast-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - eu-central-1",
|
||||
profileArn: "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC",
|
||||
expected: "https://q.eu-central-1.amazonaws.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetKiroAPIEndpointFromProfileArn(tt.profileArn)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetKiroAPIEndpointFromProfileArn(%q) = %q, want %q", tt.profileArn, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCodeWhispererLegacyEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
region string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty region - defaults to us-east-1",
|
||||
region: "",
|
||||
expected: "https://codewhisperer.us-east-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "us-east-1",
|
||||
region: "us-east-1",
|
||||
expected: "https://codewhisperer.us-east-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "us-west-2",
|
||||
region: "us-west-2",
|
||||
expected: "https://codewhisperer.us-west-2.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "ap-northeast-1",
|
||||
region: "ap-northeast-1",
|
||||
expected: "https://codewhisperer.ap-northeast-1.amazonaws.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetCodeWhispererLegacyEndpoint(tt.region)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetCodeWhispererLegacyEndpoint(%q) = %q, want %q", tt.region, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRegionFromMetadata(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
metadata map[string]interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Nil metadata - defaults to us-east-1",
|
||||
metadata: nil,
|
||||
expected: "us-east-1",
|
||||
},
|
||||
{
|
||||
name: "Empty metadata - defaults to us-east-1",
|
||||
metadata: map[string]interface{}{},
|
||||
expected: "us-east-1",
|
||||
},
|
||||
{
|
||||
name: "Priority 1: api_region override",
|
||||
metadata: map[string]interface{}{
|
||||
"api_region": "eu-west-1",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
|
||||
},
|
||||
expected: "eu-west-1",
|
||||
},
|
||||
{
|
||||
name: "Priority 2: profile_arn when api_region is empty",
|
||||
metadata: map[string]interface{}{
|
||||
"api_region": "",
|
||||
"profile_arn": "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
|
||||
},
|
||||
expected: "ap-southeast-1",
|
||||
},
|
||||
{
|
||||
name: "Priority 2: profile_arn when api_region is missing",
|
||||
metadata: map[string]interface{}{
|
||||
"profile_arn": "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC",
|
||||
},
|
||||
expected: "eu-central-1",
|
||||
},
|
||||
{
|
||||
name: "Fallback: default when profile_arn is invalid",
|
||||
metadata: map[string]interface{}{
|
||||
"profile_arn": "invalid-arn",
|
||||
},
|
||||
expected: "us-east-1",
|
||||
},
|
||||
{
|
||||
name: "Fallback: default when profile_arn is empty",
|
||||
metadata: map[string]interface{}{
|
||||
"profile_arn": "",
|
||||
},
|
||||
expected: "us-east-1",
|
||||
},
|
||||
{
|
||||
name: "OIDC region is NOT used for API region",
|
||||
metadata: map[string]interface{}{
|
||||
"region": "ap-northeast-2", // OIDC region - should be ignored
|
||||
},
|
||||
expected: "us-east-1",
|
||||
},
|
||||
{
|
||||
name: "api_region takes precedence over OIDC region",
|
||||
metadata: map[string]interface{}{
|
||||
"api_region": "us-west-2",
|
||||
"region": "ap-northeast-2", // OIDC region - should be ignored
|
||||
},
|
||||
expected: "us-west-2",
|
||||
},
|
||||
{
|
||||
name: "Non-string api_region is ignored",
|
||||
metadata: map[string]interface{}{
|
||||
"api_region": 123, // wrong type
|
||||
"profile_arn": "arn:aws:codewhisperer:ap-south-1:123456789012:profile/ABC",
|
||||
},
|
||||
expected: "ap-south-1",
|
||||
},
|
||||
{
|
||||
name: "Non-string profile_arn is ignored",
|
||||
metadata: map[string]interface{}{
|
||||
"profile_arn": 123, // wrong type
|
||||
},
|
||||
expected: "us-east-1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ExtractRegionFromMetadata(tt.metadata)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ExtractRegionFromMetadata(%v) = %q, want %q", tt.metadata, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
247
internal/auth/kiro/background_refresh.go
Normal file
247
internal/auth/kiro/background_refresh.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"golang.org/x/sync/semaphore"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
ID string
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
ExpiresAt time.Time
|
||||
LastVerified time.Time
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
AuthMethod string
|
||||
Provider string
|
||||
StartURL string
|
||||
Region string
|
||||
}
|
||||
|
||||
type TokenRepository interface {
|
||||
FindOldestUnverified(limit int) []*Token
|
||||
UpdateToken(token *Token) error
|
||||
}
|
||||
|
||||
type RefresherOption func(*BackgroundRefresher)
|
||||
|
||||
func WithInterval(interval time.Duration) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.interval = interval
|
||||
}
|
||||
}
|
||||
|
||||
func WithBatchSize(size int) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.batchSize = size
|
||||
}
|
||||
}
|
||||
|
||||
func WithConcurrency(concurrency int) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.concurrency = concurrency
|
||||
}
|
||||
}
|
||||
|
||||
type BackgroundRefresher struct {
|
||||
interval time.Duration
|
||||
batchSize int
|
||||
concurrency int
|
||||
tokenRepo TokenRepository
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
oauth *KiroOAuth
|
||||
ssoClient *SSOOIDCClient
|
||||
callbackMu sync.RWMutex // 保护回调函数的并发访问
|
||||
onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // 刷新成功回调
|
||||
}
|
||||
|
||||
func NewBackgroundRefresher(repo TokenRepository, opts ...RefresherOption) *BackgroundRefresher {
|
||||
r := &BackgroundRefresher{
|
||||
interval: time.Minute,
|
||||
batchSize: 50,
|
||||
concurrency: 10,
|
||||
tokenRepo: repo,
|
||||
stopCh: make(chan struct{}),
|
||||
oauth: nil, // Lazy init - will be set when config available
|
||||
ssoClient: nil, // Lazy init - will be set when config available
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(r)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// WithConfig sets the configuration for OAuth and SSO clients.
|
||||
func WithConfig(cfg *config.Config) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.oauth = NewKiroOAuth(cfg)
|
||||
r.ssoClient = NewSSOOIDCClient(cfg)
|
||||
}
|
||||
}
|
||||
|
||||
// WithOnTokenRefreshed sets the callback function to be called when a token is successfully refreshed.
|
||||
// The callback receives the token ID (filename) and the new token data.
|
||||
// This allows external components (e.g., Watcher) to be notified of token updates.
|
||||
func WithOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.callbackMu.Lock()
|
||||
r.onTokenRefreshed = callback
|
||||
r.callbackMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *BackgroundRefresher) Start(ctx context.Context) {
|
||||
r.wg.Add(1)
|
||||
go func() {
|
||||
defer r.wg.Done()
|
||||
ticker := time.NewTicker(r.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
r.refreshBatch(ctx)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-r.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
r.refreshBatch(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (r *BackgroundRefresher) Stop() {
|
||||
close(r.stopCh)
|
||||
r.wg.Wait()
|
||||
}
|
||||
|
||||
func (r *BackgroundRefresher) refreshBatch(ctx context.Context) {
|
||||
tokens := r.tokenRepo.FindOldestUnverified(r.batchSize)
|
||||
if len(tokens) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
sem := semaphore.NewWeighted(int64(r.concurrency))
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i, token := range tokens {
|
||||
if i > 0 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-r.stopCh:
|
||||
return
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
if err := sem.Acquire(ctx, 1); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(t *Token) {
|
||||
defer wg.Done()
|
||||
defer sem.Release(1)
|
||||
r.refreshSingle(ctx, t)
|
||||
}(token)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (r *BackgroundRefresher) refreshSingle(ctx context.Context, token *Token) {
|
||||
// Normalize auth method to lowercase for case-insensitive matching
|
||||
authMethod := strings.ToLower(token.AuthMethod)
|
||||
|
||||
// Create refresh function based on auth method
|
||||
refreshFunc := func(ctx context.Context) (*KiroTokenData, error) {
|
||||
switch authMethod {
|
||||
case "idc":
|
||||
return r.ssoClient.RefreshTokenWithRegion(
|
||||
ctx,
|
||||
token.ClientID,
|
||||
token.ClientSecret,
|
||||
token.RefreshToken,
|
||||
token.Region,
|
||||
token.StartURL,
|
||||
)
|
||||
case "builder-id":
|
||||
return r.ssoClient.RefreshToken(
|
||||
ctx,
|
||||
token.ClientID,
|
||||
token.ClientSecret,
|
||||
token.RefreshToken,
|
||||
)
|
||||
default:
|
||||
return r.oauth.RefreshTokenWithFingerprint(ctx, token.RefreshToken, token.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// Use graceful degradation for better reliability
|
||||
result := RefreshWithGracefulDegradation(
|
||||
ctx,
|
||||
refreshFunc,
|
||||
token.AccessToken,
|
||||
token.ExpiresAt,
|
||||
)
|
||||
|
||||
if result.Error != nil {
|
||||
log.Printf("failed to refresh token %s: %v", token.ID, result.Error)
|
||||
return
|
||||
}
|
||||
|
||||
newTokenData := result.TokenData
|
||||
if result.UsedFallback {
|
||||
log.Printf("token %s: using existing token as fallback (refresh failed but token still valid)", token.ID)
|
||||
// Don't update the token file if we're using fallback
|
||||
// Just update LastVerified to prevent immediate re-check
|
||||
token.LastVerified = time.Now()
|
||||
return
|
||||
}
|
||||
|
||||
token.AccessToken = newTokenData.AccessToken
|
||||
if newTokenData.RefreshToken != "" {
|
||||
token.RefreshToken = newTokenData.RefreshToken
|
||||
}
|
||||
token.LastVerified = time.Now()
|
||||
|
||||
if newTokenData.ExpiresAt != "" {
|
||||
if expTime, parseErr := time.Parse(time.RFC3339, newTokenData.ExpiresAt); parseErr == nil {
|
||||
token.ExpiresAt = expTime
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.tokenRepo.UpdateToken(token); err != nil {
|
||||
log.Printf("failed to update token %s: %v", token.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 方案 A: 刷新成功后触发回调,通知 Watcher 更新内存中的 Auth 对象
|
||||
r.callbackMu.RLock()
|
||||
callback := r.onTokenRefreshed
|
||||
r.callbackMu.RUnlock()
|
||||
|
||||
if callback != nil {
|
||||
// 使用 defer recover 隔离回调 panic,防止崩溃整个进程
|
||||
func() {
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
log.Printf("background refresh: callback panic for token %s: %v", token.ID, rec)
|
||||
}
|
||||
}()
|
||||
log.Printf("background refresh: notifying token refresh callback for %s", token.ID)
|
||||
callback(token.ID, newTokenData)
|
||||
}()
|
||||
}
|
||||
}
|
||||
153
internal/auth/kiro/codewhisperer_client.go
Normal file
153
internal/auth/kiro/codewhisperer_client.go
Normal file
@@ -0,0 +1,153 @@
|
||||
// Package kiro provides CodeWhisperer API client for fetching user info.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// CodeWhispererClient handles CodeWhisperer API calls.
|
||||
type CodeWhispererClient struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// UsageLimitsResponse represents the getUsageLimits API response.
|
||||
type UsageLimitsResponse struct {
|
||||
DaysUntilReset *int `json:"daysUntilReset,omitempty"`
|
||||
NextDateReset *float64 `json:"nextDateReset,omitempty"`
|
||||
UserInfo *UserInfo `json:"userInfo,omitempty"`
|
||||
SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
|
||||
UsageBreakdownList []UsageBreakdown `json:"usageBreakdownList,omitempty"`
|
||||
}
|
||||
|
||||
// UserInfo contains user information from the API.
|
||||
type UserInfo struct {
|
||||
Email string `json:"email,omitempty"`
|
||||
UserID string `json:"userId,omitempty"`
|
||||
}
|
||||
|
||||
// SubscriptionInfo contains subscription details.
|
||||
type SubscriptionInfo struct {
|
||||
SubscriptionTitle string `json:"subscriptionTitle,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
}
|
||||
|
||||
// UsageBreakdown contains usage details.
|
||||
type UsageBreakdown struct {
|
||||
UsageLimit *int `json:"usageLimit,omitempty"`
|
||||
CurrentUsage *int `json:"currentUsage,omitempty"`
|
||||
UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision,omitempty"`
|
||||
CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision,omitempty"`
|
||||
NextDateReset *float64 `json:"nextDateReset,omitempty"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
ResourceType string `json:"resourceType,omitempty"`
|
||||
}
|
||||
|
||||
// NewCodeWhispererClient creates a new CodeWhisperer client.
|
||||
func NewCodeWhispererClient(cfg *config.Config, machineID string) *CodeWhispererClient {
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
if cfg != nil {
|
||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||
}
|
||||
return &CodeWhispererClient{
|
||||
httpClient: client,
|
||||
}
|
||||
}
|
||||
|
||||
// GetUsageLimits fetches usage limits and user info from CodeWhisperer API.
|
||||
// This is the recommended way to get user email after login.
|
||||
func (c *CodeWhispererClient) GetUsageLimits(ctx context.Context, accessToken, clientID, refreshToken, profileArn string) (*UsageLimitsResponse, error) {
|
||||
queryParams := map[string]string{
|
||||
"origin": "AI_EDITOR",
|
||||
"resourceType": "AGENTIC_REQUEST",
|
||||
}
|
||||
// Determine endpoint based on profileArn region
|
||||
endpoint := GetKiroAPIEndpointFromProfileArn(profileArn)
|
||||
if profileArn != "" {
|
||||
queryParams["profileArn"] = profileArn
|
||||
} else {
|
||||
queryParams["isEmailRequired"] = "true"
|
||||
}
|
||||
url := buildURL(endpoint, pathGetUsageLimits, queryParams)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
accountKey := GetAccountKey(clientID, refreshToken)
|
||||
setRuntimeHeaders(req, accessToken, accountKey)
|
||||
|
||||
log.Debugf("codewhisperer: GET %s", url)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("codewhisperer: status=%d, body=%s", resp.StatusCode, string(body))
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result UsageLimitsResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// FetchUserEmailFromAPI fetches user email using CodeWhisperer getUsageLimits API.
|
||||
// This is more reliable than JWT parsing as it uses the official API.
|
||||
func (c *CodeWhispererClient) FetchUserEmailFromAPI(ctx context.Context, accessToken, clientID, refreshToken string) string {
|
||||
resp, err := c.GetUsageLimits(ctx, accessToken, clientID, refreshToken, "")
|
||||
if err != nil {
|
||||
log.Debugf("codewhisperer: failed to get usage limits: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
if resp.UserInfo != nil && resp.UserInfo.Email != "" {
|
||||
log.Debugf("codewhisperer: got email from API: %s", resp.UserInfo.Email)
|
||||
return resp.UserInfo.Email
|
||||
}
|
||||
|
||||
log.Debugf("codewhisperer: no email in response")
|
||||
return ""
|
||||
}
|
||||
|
||||
// FetchUserEmailWithFallback fetches user email with multiple fallback methods.
|
||||
// Priority: 1. CodeWhisperer API 2. userinfo endpoint 3. JWT parsing
|
||||
func FetchUserEmailWithFallback(ctx context.Context, cfg *config.Config, accessToken, clientID, refreshToken string) string {
|
||||
// Method 1: Try CodeWhisperer API (most reliable)
|
||||
cwClient := NewCodeWhispererClient(cfg, "")
|
||||
email := cwClient.FetchUserEmailFromAPI(ctx, accessToken, clientID, refreshToken)
|
||||
if email != "" {
|
||||
return email
|
||||
}
|
||||
|
||||
// Method 2: Try SSO OIDC userinfo endpoint
|
||||
ssoClient := NewSSOOIDCClient(cfg)
|
||||
email = ssoClient.FetchUserEmail(ctx, accessToken)
|
||||
if email != "" {
|
||||
return email
|
||||
}
|
||||
|
||||
// Method 3: Fallback to JWT parsing
|
||||
return ExtractEmailFromJWT(accessToken)
|
||||
}
|
||||
112
internal/auth/kiro/cooldown.go
Normal file
112
internal/auth/kiro/cooldown.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
CooldownReason429 = "rate_limit_exceeded"
|
||||
CooldownReasonSuspended = "account_suspended"
|
||||
CooldownReasonQuotaExhausted = "quota_exhausted"
|
||||
|
||||
DefaultShortCooldown = 1 * time.Minute
|
||||
MaxShortCooldown = 5 * time.Minute
|
||||
LongCooldown = 24 * time.Hour
|
||||
)
|
||||
|
||||
type CooldownManager struct {
|
||||
mu sync.RWMutex
|
||||
cooldowns map[string]time.Time
|
||||
reasons map[string]string
|
||||
}
|
||||
|
||||
func NewCooldownManager() *CooldownManager {
|
||||
return &CooldownManager{
|
||||
cooldowns: make(map[string]time.Time),
|
||||
reasons: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) SetCooldown(tokenKey string, duration time.Duration, reason string) {
|
||||
cm.mu.Lock()
|
||||
defer cm.mu.Unlock()
|
||||
cm.cooldowns[tokenKey] = time.Now().Add(duration)
|
||||
cm.reasons[tokenKey] = reason
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) IsInCooldown(tokenKey string) bool {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
endTime, exists := cm.cooldowns[tokenKey]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
return time.Now().Before(endTime)
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) GetRemainingCooldown(tokenKey string) time.Duration {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
endTime, exists := cm.cooldowns[tokenKey]
|
||||
if !exists {
|
||||
return 0
|
||||
}
|
||||
remaining := time.Until(endTime)
|
||||
if remaining < 0 {
|
||||
return 0
|
||||
}
|
||||
return remaining
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) GetCooldownReason(tokenKey string) string {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return cm.reasons[tokenKey]
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) ClearCooldown(tokenKey string) {
|
||||
cm.mu.Lock()
|
||||
defer cm.mu.Unlock()
|
||||
delete(cm.cooldowns, tokenKey)
|
||||
delete(cm.reasons, tokenKey)
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) CleanupExpired() {
|
||||
cm.mu.Lock()
|
||||
defer cm.mu.Unlock()
|
||||
now := time.Now()
|
||||
for tokenKey, endTime := range cm.cooldowns {
|
||||
if now.After(endTime) {
|
||||
delete(cm.cooldowns, tokenKey)
|
||||
delete(cm.reasons, tokenKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) StartCleanupRoutine(interval time.Duration, stopCh <-chan struct{}) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
cm.CleanupExpired()
|
||||
case <-stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CalculateCooldownFor429(retryCount int) time.Duration {
|
||||
duration := DefaultShortCooldown * time.Duration(1<<retryCount)
|
||||
if duration > MaxShortCooldown {
|
||||
return MaxShortCooldown
|
||||
}
|
||||
return duration
|
||||
}
|
||||
|
||||
func CalculateCooldownUntilNextDay() time.Duration {
|
||||
now := time.Now()
|
||||
nextDay := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, now.Location())
|
||||
return time.Until(nextDay)
|
||||
}
|
||||
240
internal/auth/kiro/cooldown_test.go
Normal file
240
internal/auth/kiro/cooldown_test.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewCooldownManager(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
if cm == nil {
|
||||
t.Fatal("expected non-nil CooldownManager")
|
||||
}
|
||||
if cm.cooldowns == nil {
|
||||
t.Error("expected non-nil cooldowns map")
|
||||
}
|
||||
if cm.reasons == nil {
|
||||
t.Error("expected non-nil reasons map")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCooldown(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Minute, CooldownReason429)
|
||||
|
||||
if !cm.IsInCooldown("token1") {
|
||||
t.Error("expected token to be in cooldown")
|
||||
}
|
||||
if cm.GetCooldownReason("token1") != CooldownReason429 {
|
||||
t.Errorf("expected reason %s, got %s", CooldownReason429, cm.GetCooldownReason("token1"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsInCooldown_NotSet(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
if cm.IsInCooldown("nonexistent") {
|
||||
t.Error("expected non-existent token to not be in cooldown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsInCooldown_Expired(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Millisecond, CooldownReason429)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
if cm.IsInCooldown("token1") {
|
||||
t.Error("expected expired cooldown to return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRemainingCooldown(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Second, CooldownReason429)
|
||||
|
||||
remaining := cm.GetRemainingCooldown("token1")
|
||||
if remaining <= 0 || remaining > 1*time.Second {
|
||||
t.Errorf("expected remaining cooldown between 0 and 1s, got %v", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRemainingCooldown_NotSet(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
remaining := cm.GetRemainingCooldown("nonexistent")
|
||||
if remaining != 0 {
|
||||
t.Errorf("expected 0 remaining for non-existent, got %v", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRemainingCooldown_Expired(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Millisecond, CooldownReason429)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
remaining := cm.GetRemainingCooldown("token1")
|
||||
if remaining != 0 {
|
||||
t.Errorf("expected 0 remaining for expired, got %v", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCooldownReason(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Minute, CooldownReasonSuspended)
|
||||
|
||||
reason := cm.GetCooldownReason("token1")
|
||||
if reason != CooldownReasonSuspended {
|
||||
t.Errorf("expected reason %s, got %s", CooldownReasonSuspended, reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCooldownReason_NotSet(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
reason := cm.GetCooldownReason("nonexistent")
|
||||
if reason != "" {
|
||||
t.Errorf("expected empty reason for non-existent, got %s", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearCooldown(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Minute, CooldownReason429)
|
||||
cm.ClearCooldown("token1")
|
||||
|
||||
if cm.IsInCooldown("token1") {
|
||||
t.Error("expected cooldown to be cleared")
|
||||
}
|
||||
if cm.GetCooldownReason("token1") != "" {
|
||||
t.Error("expected reason to be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearCooldown_NonExistent(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.ClearCooldown("nonexistent")
|
||||
}
|
||||
|
||||
func TestCleanupExpired(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("expired1", 1*time.Millisecond, CooldownReason429)
|
||||
cm.SetCooldown("expired2", 1*time.Millisecond, CooldownReason429)
|
||||
cm.SetCooldown("active", 1*time.Hour, CooldownReason429)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
cm.CleanupExpired()
|
||||
|
||||
if cm.GetCooldownReason("expired1") != "" {
|
||||
t.Error("expected expired1 to be cleaned up")
|
||||
}
|
||||
if cm.GetCooldownReason("expired2") != "" {
|
||||
t.Error("expected expired2 to be cleaned up")
|
||||
}
|
||||
if cm.GetCooldownReason("active") != CooldownReason429 {
|
||||
t.Error("expected active to remain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCooldownFor429_FirstRetry(t *testing.T) {
|
||||
duration := CalculateCooldownFor429(0)
|
||||
if duration != DefaultShortCooldown {
|
||||
t.Errorf("expected %v for retry 0, got %v", DefaultShortCooldown, duration)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCooldownFor429_Exponential(t *testing.T) {
|
||||
d1 := CalculateCooldownFor429(1)
|
||||
d2 := CalculateCooldownFor429(2)
|
||||
|
||||
if d2 <= d1 {
|
||||
t.Errorf("expected d2 > d1, got d1=%v, d2=%v", d1, d2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCooldownFor429_MaxCap(t *testing.T) {
|
||||
duration := CalculateCooldownFor429(10)
|
||||
if duration > MaxShortCooldown {
|
||||
t.Errorf("expected max %v, got %v", MaxShortCooldown, duration)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCooldownUntilNextDay(t *testing.T) {
|
||||
duration := CalculateCooldownUntilNextDay()
|
||||
if duration <= 0 || duration > 24*time.Hour {
|
||||
t.Errorf("expected duration between 0 and 24h, got %v", duration)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldownManager_ConcurrentAccess(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
const numGoroutines = 50
|
||||
const numOperations = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
tokenKey := "token" + string(rune('a'+id%10))
|
||||
for j := 0; j < numOperations; j++ {
|
||||
switch j % 6 {
|
||||
case 0:
|
||||
cm.SetCooldown(tokenKey, time.Duration(j)*time.Millisecond, CooldownReason429)
|
||||
case 1:
|
||||
cm.IsInCooldown(tokenKey)
|
||||
case 2:
|
||||
cm.GetRemainingCooldown(tokenKey)
|
||||
case 3:
|
||||
cm.GetCooldownReason(tokenKey)
|
||||
case 4:
|
||||
cm.ClearCooldown(tokenKey)
|
||||
case 5:
|
||||
cm.CleanupExpired()
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestCooldownReasonConstants(t *testing.T) {
|
||||
if CooldownReason429 != "rate_limit_exceeded" {
|
||||
t.Errorf("unexpected CooldownReason429: %s", CooldownReason429)
|
||||
}
|
||||
if CooldownReasonSuspended != "account_suspended" {
|
||||
t.Errorf("unexpected CooldownReasonSuspended: %s", CooldownReasonSuspended)
|
||||
}
|
||||
if CooldownReasonQuotaExhausted != "quota_exhausted" {
|
||||
t.Errorf("unexpected CooldownReasonQuotaExhausted: %s", CooldownReasonQuotaExhausted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConstants(t *testing.T) {
|
||||
if DefaultShortCooldown != 1*time.Minute {
|
||||
t.Errorf("unexpected DefaultShortCooldown: %v", DefaultShortCooldown)
|
||||
}
|
||||
if MaxShortCooldown != 5*time.Minute {
|
||||
t.Errorf("unexpected MaxShortCooldown: %v", MaxShortCooldown)
|
||||
}
|
||||
if LongCooldown != 24*time.Hour {
|
||||
t.Errorf("unexpected LongCooldown: %v", LongCooldown)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCooldown_OverwritesPrevious(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Hour, CooldownReason429)
|
||||
cm.SetCooldown("token1", 1*time.Minute, CooldownReasonSuspended)
|
||||
|
||||
reason := cm.GetCooldownReason("token1")
|
||||
if reason != CooldownReasonSuspended {
|
||||
t.Errorf("expected reason to be overwritten to %s, got %s", CooldownReasonSuspended, reason)
|
||||
}
|
||||
|
||||
remaining := cm.GetRemainingCooldown("token1")
|
||||
if remaining > 1*time.Minute {
|
||||
t.Errorf("expected remaining <= 1 minute, got %v", remaining)
|
||||
}
|
||||
}
|
||||
278
internal/auth/kiro/fingerprint.go
Normal file
278
internal/auth/kiro/fingerprint.go
Normal file
@@ -0,0 +1,278 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Fingerprint holds multi-dimensional fingerprint data for runtime request disguise.
|
||||
type Fingerprint struct {
|
||||
OIDCSDKVersion string // 3.7xx (AWS SDK JS)
|
||||
RuntimeSDKVersion string // 1.0.x (runtime API)
|
||||
StreamingSDKVersion string // 1.0.x (streaming API)
|
||||
OSType string // darwin/windows/linux
|
||||
OSVersion string
|
||||
NodeVersion string
|
||||
KiroVersion string
|
||||
KiroHash string // SHA256
|
||||
}
|
||||
|
||||
// FingerprintConfig holds external fingerprint overrides.
|
||||
type FingerprintConfig struct {
|
||||
OIDCSDKVersion string
|
||||
RuntimeSDKVersion string
|
||||
StreamingSDKVersion string
|
||||
OSType string
|
||||
OSVersion string
|
||||
NodeVersion string
|
||||
KiroVersion string
|
||||
KiroHash string
|
||||
}
|
||||
|
||||
// FingerprintManager manages per-account fingerprint generation and caching.
|
||||
type FingerprintManager struct {
|
||||
mu sync.RWMutex
|
||||
fingerprints map[string]*Fingerprint // tokenKey -> fingerprint
|
||||
rng *rand.Rand
|
||||
config *FingerprintConfig // External config (Optional)
|
||||
}
|
||||
|
||||
var (
|
||||
// SDK versions
|
||||
oidcSDKVersions = []string{
|
||||
"3.980.0", "3.975.0", "3.972.0", "3.808.0",
|
||||
"3.738.0", "3.737.0", "3.736.0", "3.735.0",
|
||||
}
|
||||
// SDKVersions for getUsageLimits/ListAvailableModels/GetProfile (runtime API)
|
||||
runtimeSDKVersions = []string{"1.0.0"}
|
||||
// SDKVersions for generateAssistantResponse (streaming API)
|
||||
streamingSDKVersions = []string{"1.0.27"}
|
||||
// Valid OS types
|
||||
osTypes = []string{"darwin", "windows", "linux"}
|
||||
// OS versions
|
||||
osVersions = map[string][]string{
|
||||
"darwin": {"25.2.0", "25.1.0", "25.0.0", "24.5.0", "24.4.0", "24.3.0"},
|
||||
"windows": {"10.0.26200", "10.0.26100", "10.0.22631", "10.0.22621", "10.0.19045"},
|
||||
"linux": {"6.12.0", "6.11.0", "6.8.0", "6.6.0", "6.5.0", "6.1.0"},
|
||||
}
|
||||
// Node versions
|
||||
nodeVersions = []string{
|
||||
"22.21.1", "22.21.0", "22.20.0", "22.19.0", "22.18.0",
|
||||
"20.18.0", "20.17.0", "20.16.0",
|
||||
}
|
||||
// Kiro IDE versions
|
||||
kiroVersions = []string{
|
||||
"0.10.32", "0.10.16", "0.10.10",
|
||||
"0.9.47", "0.9.40", "0.9.2",
|
||||
"0.8.206", "0.8.140", "0.8.135", "0.8.86",
|
||||
}
|
||||
// Global singleton
|
||||
globalFingerprintManager *FingerprintManager
|
||||
globalFingerprintManagerOnce sync.Once
|
||||
)
|
||||
|
||||
func GlobalFingerprintManager() *FingerprintManager {
|
||||
globalFingerprintManagerOnce.Do(func() {
|
||||
globalFingerprintManager = NewFingerprintManager()
|
||||
})
|
||||
return globalFingerprintManager
|
||||
}
|
||||
|
||||
func SetGlobalFingerprintConfig(cfg *FingerprintConfig) {
|
||||
GlobalFingerprintManager().SetConfig(cfg)
|
||||
}
|
||||
|
||||
// SetConfig applies the config and clears the fingerprint cache.
|
||||
func (fm *FingerprintManager) SetConfig(cfg *FingerprintConfig) {
|
||||
fm.mu.Lock()
|
||||
defer fm.mu.Unlock()
|
||||
fm.config = cfg
|
||||
// Clear cached fingerprints so they regenerate with the new config
|
||||
fm.fingerprints = make(map[string]*Fingerprint)
|
||||
}
|
||||
|
||||
func NewFingerprintManager() *FingerprintManager {
|
||||
return &FingerprintManager{
|
||||
fingerprints: make(map[string]*Fingerprint),
|
||||
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
// GetFingerprint returns the fingerprint for tokenKey, creating one if it doesn't exist.
|
||||
func (fm *FingerprintManager) GetFingerprint(tokenKey string) *Fingerprint {
|
||||
fm.mu.RLock()
|
||||
if fp, exists := fm.fingerprints[tokenKey]; exists {
|
||||
fm.mu.RUnlock()
|
||||
return fp
|
||||
}
|
||||
fm.mu.RUnlock()
|
||||
|
||||
fm.mu.Lock()
|
||||
defer fm.mu.Unlock()
|
||||
|
||||
if fp, exists := fm.fingerprints[tokenKey]; exists {
|
||||
return fp
|
||||
}
|
||||
|
||||
fp := fm.generateFingerprint(tokenKey)
|
||||
fm.fingerprints[tokenKey] = fp
|
||||
return fp
|
||||
}
|
||||
|
||||
func (fm *FingerprintManager) generateFingerprint(tokenKey string) *Fingerprint {
|
||||
if fm.config != nil {
|
||||
return fm.generateFromConfig(tokenKey)
|
||||
}
|
||||
return fm.generateRandom(tokenKey)
|
||||
}
|
||||
|
||||
// generateFromConfig uses config values, falling back to random for empty fields.
|
||||
func (fm *FingerprintManager) generateFromConfig(tokenKey string) *Fingerprint {
|
||||
cfg := fm.config
|
||||
|
||||
// Helper: config value or random selection
|
||||
configOrRandom := func(configVal string, choices []string) string {
|
||||
if configVal != "" {
|
||||
return configVal
|
||||
}
|
||||
return choices[fm.rng.Intn(len(choices))]
|
||||
}
|
||||
|
||||
osType := cfg.OSType
|
||||
if osType == "" {
|
||||
osType = runtime.GOOS
|
||||
if !slices.Contains(osTypes, osType) {
|
||||
osType = osTypes[fm.rng.Intn(len(osTypes))]
|
||||
}
|
||||
}
|
||||
|
||||
osVersion := cfg.OSVersion
|
||||
if osVersion == "" {
|
||||
if versions, ok := osVersions[osType]; ok {
|
||||
osVersion = versions[fm.rng.Intn(len(versions))]
|
||||
}
|
||||
}
|
||||
|
||||
kiroHash := cfg.KiroHash
|
||||
if kiroHash == "" {
|
||||
hash := sha256.Sum256([]byte(tokenKey))
|
||||
kiroHash = hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
return &Fingerprint{
|
||||
OIDCSDKVersion: configOrRandom(cfg.OIDCSDKVersion, oidcSDKVersions),
|
||||
RuntimeSDKVersion: configOrRandom(cfg.RuntimeSDKVersion, runtimeSDKVersions),
|
||||
StreamingSDKVersion: configOrRandom(cfg.StreamingSDKVersion, streamingSDKVersions),
|
||||
OSType: osType,
|
||||
OSVersion: osVersion,
|
||||
NodeVersion: configOrRandom(cfg.NodeVersion, nodeVersions),
|
||||
KiroVersion: configOrRandom(cfg.KiroVersion, kiroVersions),
|
||||
KiroHash: kiroHash,
|
||||
}
|
||||
}
|
||||
|
||||
// generateRandom generates a deterministic fingerprint seeded by accountKey hash.
|
||||
func (fm *FingerprintManager) generateRandom(accountKey string) *Fingerprint {
|
||||
// Use accountKey hash as seed for deterministic random selection
|
||||
hash := sha256.Sum256([]byte(accountKey))
|
||||
seed := int64(binary.BigEndian.Uint64(hash[:8]))
|
||||
rng := rand.New(rand.NewSource(seed))
|
||||
|
||||
osType := runtime.GOOS
|
||||
if !slices.Contains(osTypes, osType) {
|
||||
osType = osTypes[rng.Intn(len(osTypes))]
|
||||
}
|
||||
osVersion := osVersions[osType][rng.Intn(len(osVersions[osType]))]
|
||||
|
||||
return &Fingerprint{
|
||||
OIDCSDKVersion: oidcSDKVersions[rng.Intn(len(oidcSDKVersions))],
|
||||
RuntimeSDKVersion: runtimeSDKVersions[rng.Intn(len(runtimeSDKVersions))],
|
||||
StreamingSDKVersion: streamingSDKVersions[rng.Intn(len(streamingSDKVersions))],
|
||||
OSType: osType,
|
||||
OSVersion: osVersion,
|
||||
NodeVersion: nodeVersions[rng.Intn(len(nodeVersions))],
|
||||
KiroVersion: kiroVersions[rng.Intn(len(kiroVersions))],
|
||||
KiroHash: hex.EncodeToString(hash[:]),
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateAccountKey returns a 16-char hex key derived from SHA256(seed).
|
||||
func GenerateAccountKey(seed string) string {
|
||||
hash := sha256.Sum256([]byte(seed))
|
||||
return hex.EncodeToString(hash[:8])
|
||||
}
|
||||
|
||||
// GetAccountKey derives an account key from clientID > refreshToken > random UUID.
|
||||
func GetAccountKey(clientID, refreshToken string) string {
|
||||
// 1. Prefer ClientID
|
||||
if clientID != "" {
|
||||
return GenerateAccountKey(clientID)
|
||||
}
|
||||
|
||||
// 2. Fallback to RefreshToken
|
||||
if refreshToken != "" {
|
||||
return GenerateAccountKey(refreshToken)
|
||||
}
|
||||
|
||||
// 3. Random fallback
|
||||
return GenerateAccountKey(uuid.New().String())
|
||||
}
|
||||
|
||||
// BuildUserAgent format: aws-sdk-js/{SDKVersion} ua/2.1 os/{OSType}#{OSVersion} lang/js md/nodejs#{NodeVersion} api/codewhispererstreaming#{SDKVersion} m/E KiroIDE-{KiroVersion}-{KiroHash}
|
||||
func (fp *Fingerprint) BuildUserAgent() string {
|
||||
return fmt.Sprintf(
|
||||
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
|
||||
fp.StreamingSDKVersion,
|
||||
fp.OSType,
|
||||
fp.OSVersion,
|
||||
fp.NodeVersion,
|
||||
fp.StreamingSDKVersion,
|
||||
fp.KiroVersion,
|
||||
fp.KiroHash,
|
||||
)
|
||||
}
|
||||
|
||||
// BuildAmzUserAgent format: aws-sdk-js/{SDKVersion} KiroIDE-{KiroVersion}-{KiroHash}
|
||||
func (fp *Fingerprint) BuildAmzUserAgent() string {
|
||||
return fmt.Sprintf(
|
||||
"aws-sdk-js/%s KiroIDE-%s-%s",
|
||||
fp.StreamingSDKVersion,
|
||||
fp.KiroVersion,
|
||||
fp.KiroHash,
|
||||
)
|
||||
}
|
||||
|
||||
func SetOIDCHeaders(req *http.Request) {
|
||||
fp := GlobalFingerprintManager().GetFingerprint("oidc-session")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-amz-user-agent", fmt.Sprintf("aws-sdk-js/%s KiroIDE", fp.OIDCSDKVersion))
|
||||
req.Header.Set("User-Agent", fmt.Sprintf(
|
||||
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/%s#%s m/E KiroIDE",
|
||||
fp.OIDCSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, "sso-oidc", fp.OIDCSDKVersion))
|
||||
req.Header.Set("amz-sdk-invocation-id", uuid.New().String())
|
||||
req.Header.Set("amz-sdk-request", "attempt=1; max=4")
|
||||
}
|
||||
|
||||
func setRuntimeHeaders(req *http.Request, accessToken string, accountKey string) {
|
||||
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
|
||||
machineID := fp.KiroHash
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("x-amz-user-agent", fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s",
|
||||
fp.RuntimeSDKVersion, fp.KiroVersion, machineID))
|
||||
req.Header.Set("User-Agent", fmt.Sprintf(
|
||||
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererruntime#%s m/N,E KiroIDE-%s-%s",
|
||||
fp.RuntimeSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, fp.RuntimeSDKVersion,
|
||||
fp.KiroVersion, machineID))
|
||||
req.Header.Set("amz-sdk-invocation-id", uuid.New().String())
|
||||
req.Header.Set("amz-sdk-request", "attempt=1; max=1")
|
||||
}
|
||||
778
internal/auth/kiro/fingerprint_test.go
Normal file
778
internal/auth/kiro/fingerprint_test.go
Normal file
@@ -0,0 +1,778 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewFingerprintManager(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
if fm == nil {
|
||||
t.Fatal("expected non-nil FingerprintManager")
|
||||
}
|
||||
if fm.fingerprints == nil {
|
||||
t.Error("expected non-nil fingerprints map")
|
||||
}
|
||||
if fm.rng == nil {
|
||||
t.Error("expected non-nil rng")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFingerprint_NewToken(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
if fp == nil {
|
||||
t.Fatal("expected non-nil Fingerprint")
|
||||
}
|
||||
if fp.OIDCSDKVersion == "" {
|
||||
t.Error("expected non-empty OIDCSDKVersion")
|
||||
}
|
||||
if fp.RuntimeSDKVersion == "" {
|
||||
t.Error("expected non-empty RuntimeSDKVersion")
|
||||
}
|
||||
if fp.StreamingSDKVersion == "" {
|
||||
t.Error("expected non-empty StreamingSDKVersion")
|
||||
}
|
||||
if fp.OSType == "" {
|
||||
t.Error("expected non-empty OSType")
|
||||
}
|
||||
if fp.OSVersion == "" {
|
||||
t.Error("expected non-empty OSVersion")
|
||||
}
|
||||
if fp.NodeVersion == "" {
|
||||
t.Error("expected non-empty NodeVersion")
|
||||
}
|
||||
if fp.KiroVersion == "" {
|
||||
t.Error("expected non-empty KiroVersion")
|
||||
}
|
||||
if fp.KiroHash == "" {
|
||||
t.Error("expected non-empty KiroHash")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFingerprint_SameTokenReturnsSameFingerprint(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp1 := fm.GetFingerprint("token1")
|
||||
fp2 := fm.GetFingerprint("token1")
|
||||
|
||||
if fp1 != fp2 {
|
||||
t.Error("expected same fingerprint for same token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFingerprint_DifferentTokens(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp1 := fm.GetFingerprint("token1")
|
||||
fp2 := fm.GetFingerprint("token2")
|
||||
|
||||
if fp1 == fp2 {
|
||||
t.Error("expected different fingerprints for different tokens")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildUserAgent(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
ua := fp.BuildUserAgent()
|
||||
if ua == "" {
|
||||
t.Error("expected non-empty User-Agent")
|
||||
}
|
||||
|
||||
amzUA := fp.BuildAmzUserAgent()
|
||||
if amzUA == "" {
|
||||
t.Error("expected non-empty X-Amz-User-Agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFingerprint_OSVersionMatchesOSType(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
fp := fm.GetFingerprint("token" + string(rune('a'+i)))
|
||||
validVersions := osVersions[fp.OSType]
|
||||
found := false
|
||||
for _, v := range validVersions {
|
||||
if v == fp.OSVersion {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("OS version %s not valid for OS type %s", fp.OSVersion, fp.OSType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateFromConfig_OSTypeFromRuntimeGOOS(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
|
||||
// Set config with empty OSType to trigger runtime.GOOS fallback
|
||||
fm.SetConfig(&FingerprintConfig{
|
||||
OIDCSDKVersion: "3.738.0", // Set other fields to use config path
|
||||
})
|
||||
|
||||
fp := fm.GetFingerprint("test-token")
|
||||
|
||||
// Expected OS type based on runtime.GOOS mapping
|
||||
var expectedOS string
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
expectedOS = "darwin"
|
||||
case "windows":
|
||||
expectedOS = "windows"
|
||||
default:
|
||||
expectedOS = "linux"
|
||||
}
|
||||
|
||||
if fp.OSType != expectedOS {
|
||||
t.Errorf("expected OSType '%s' from runtime.GOOS '%s', got '%s'",
|
||||
expectedOS, runtime.GOOS, fp.OSType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
const numGoroutines = 100
|
||||
const numOperations = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := range numGoroutines {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := range numOperations {
|
||||
tokenKey := "token" + string(rune('a'+id%26))
|
||||
switch j % 2 {
|
||||
case 0:
|
||||
fm.GetFingerprint(tokenKey)
|
||||
case 1:
|
||||
fp := fm.GetFingerprint(tokenKey)
|
||||
_ = fp.BuildUserAgent()
|
||||
_ = fp.BuildAmzUserAgent()
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestKiroHashStability(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
|
||||
// Same token should always return same hash
|
||||
fp1 := fm.GetFingerprint("token1")
|
||||
fp2 := fm.GetFingerprint("token1")
|
||||
if fp1.KiroHash != fp2.KiroHash {
|
||||
t.Errorf("same token should have same hash: %s vs %s", fp1.KiroHash, fp2.KiroHash)
|
||||
}
|
||||
|
||||
// Different tokens should have different hashes
|
||||
fp3 := fm.GetFingerprint("token2")
|
||||
if fp1.KiroHash == fp3.KiroHash {
|
||||
t.Errorf("different tokens should have different hashes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKiroHashFormat(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
if len(fp.KiroHash) != 64 {
|
||||
t.Errorf("expected KiroHash length 64 (SHA256 hex), got %d", len(fp.KiroHash))
|
||||
}
|
||||
|
||||
for _, c := range fp.KiroHash {
|
||||
if (c < '0' || c > '9') && (c < 'a' || c > 'f') {
|
||||
t.Errorf("invalid hex character in KiroHash: %c", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalFingerprintManager(t *testing.T) {
|
||||
fm1 := GlobalFingerprintManager()
|
||||
fm2 := GlobalFingerprintManager()
|
||||
|
||||
if fm1 == nil {
|
||||
t.Fatal("expected non-nil GlobalFingerprintManager")
|
||||
}
|
||||
if fm1 != fm2 {
|
||||
t.Error("expected GlobalFingerprintManager to return same instance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetOIDCHeaders(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
SetOIDCHeaders(req)
|
||||
|
||||
if req.Header.Get("Content-Type") != "application/json" {
|
||||
t.Error("expected Content-Type header to be set")
|
||||
}
|
||||
|
||||
amzUA := req.Header.Get("x-amz-user-agent")
|
||||
if amzUA == "" {
|
||||
t.Error("expected x-amz-user-agent header to be set")
|
||||
}
|
||||
if !strings.Contains(amzUA, "aws-sdk-js/") {
|
||||
t.Errorf("x-amz-user-agent should contain aws-sdk-js: %s", amzUA)
|
||||
}
|
||||
if !strings.Contains(amzUA, "KiroIDE") {
|
||||
t.Errorf("x-amz-user-agent should contain KiroIDE: %s", amzUA)
|
||||
}
|
||||
|
||||
ua := req.Header.Get("User-Agent")
|
||||
if ua == "" {
|
||||
t.Error("expected User-Agent header to be set")
|
||||
}
|
||||
if !strings.Contains(ua, "api/sso-oidc") {
|
||||
t.Errorf("User-Agent should contain api name: %s", ua)
|
||||
}
|
||||
|
||||
if req.Header.Get("amz-sdk-invocation-id") == "" {
|
||||
t.Error("expected amz-sdk-invocation-id header to be set")
|
||||
}
|
||||
if req.Header.Get("amz-sdk-request") != "attempt=1; max=4" {
|
||||
t.Errorf("unexpected amz-sdk-request header: %s", req.Header.Get("amz-sdk-request"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
endpoint string
|
||||
path string
|
||||
queryParams map[string]string
|
||||
want string
|
||||
wantContains []string
|
||||
}{
|
||||
{
|
||||
name: "no query params",
|
||||
endpoint: "https://api.example.com",
|
||||
path: "getUsageLimits",
|
||||
queryParams: nil,
|
||||
want: "https://api.example.com/getUsageLimits",
|
||||
},
|
||||
{
|
||||
name: "empty query params",
|
||||
endpoint: "https://api.example.com",
|
||||
path: "getUsageLimits",
|
||||
queryParams: map[string]string{},
|
||||
want: "https://api.example.com/getUsageLimits",
|
||||
},
|
||||
{
|
||||
name: "single query param",
|
||||
endpoint: "https://api.example.com",
|
||||
path: "getUsageLimits",
|
||||
queryParams: map[string]string{
|
||||
"origin": "AI_EDITOR",
|
||||
},
|
||||
want: "https://api.example.com/getUsageLimits?origin=AI_EDITOR",
|
||||
},
|
||||
{
|
||||
name: "multiple query params",
|
||||
endpoint: "https://api.example.com",
|
||||
path: "getUsageLimits",
|
||||
queryParams: map[string]string{
|
||||
"origin": "AI_EDITOR",
|
||||
"resourceType": "AGENTIC_REQUEST",
|
||||
"profileArn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEF",
|
||||
},
|
||||
wantContains: []string{
|
||||
"https://api.example.com/getUsageLimits?",
|
||||
"origin=AI_EDITOR",
|
||||
"profileArn=arn%3Aaws%3Acodewhisperer%3Aus-east-1%3A123456789012%3Aprofile%2FABCDEF",
|
||||
"resourceType=AGENTIC_REQUEST",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "omit empty params",
|
||||
endpoint: "https://api.example.com",
|
||||
path: "getUsageLimits",
|
||||
queryParams: map[string]string{
|
||||
"origin": "AI_EDITOR",
|
||||
"profileArn": "",
|
||||
},
|
||||
want: "https://api.example.com/getUsageLimits?origin=AI_EDITOR",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := buildURL(tt.endpoint, tt.path, tt.queryParams)
|
||||
if tt.want != "" {
|
||||
if got != tt.want {
|
||||
t.Errorf("buildURL() = %v, want %v", got, tt.want)
|
||||
}
|
||||
}
|
||||
if tt.wantContains != nil {
|
||||
for _, substr := range tt.wantContains {
|
||||
if !strings.Contains(got, substr) {
|
||||
t.Errorf("buildURL() = %v, want to contain %v", got, substr)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildUserAgentFormat(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
ua := fp.BuildUserAgent()
|
||||
requiredParts := []string{
|
||||
"aws-sdk-js/",
|
||||
"ua/2.1",
|
||||
"os/",
|
||||
"lang/js",
|
||||
"md/nodejs#",
|
||||
"api/codewhispererstreaming#",
|
||||
"m/E",
|
||||
"KiroIDE-",
|
||||
}
|
||||
for _, part := range requiredParts {
|
||||
if !strings.Contains(ua, part) {
|
||||
t.Errorf("User-Agent missing required part %q: %s", part, ua)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAmzUserAgentFormat(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
amzUA := fp.BuildAmzUserAgent()
|
||||
requiredParts := []string{
|
||||
"aws-sdk-js/",
|
||||
"KiroIDE-",
|
||||
}
|
||||
for _, part := range requiredParts {
|
||||
if !strings.Contains(amzUA, part) {
|
||||
t.Errorf("X-Amz-User-Agent missing required part %q: %s", part, amzUA)
|
||||
}
|
||||
}
|
||||
|
||||
// Amz-User-Agent should be shorter than User-Agent
|
||||
ua := fp.BuildUserAgent()
|
||||
if len(amzUA) >= len(ua) {
|
||||
t.Error("X-Amz-User-Agent should be shorter than User-Agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetRuntimeHeaders(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
accessToken := "test-access-token-1234567890"
|
||||
clientID := "test-client-id-12345"
|
||||
accountKey := GenerateAccountKey(clientID)
|
||||
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
|
||||
machineID := fp.KiroHash
|
||||
|
||||
setRuntimeHeaders(req, accessToken, accountKey)
|
||||
|
||||
// Check Authorization header
|
||||
if req.Header.Get("Authorization") != "Bearer "+accessToken {
|
||||
t.Errorf("expected Authorization header 'Bearer %s', got '%s'", accessToken, req.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
// Check x-amz-user-agent header
|
||||
amzUA := req.Header.Get("x-amz-user-agent")
|
||||
if amzUA == "" {
|
||||
t.Error("expected x-amz-user-agent header to be set")
|
||||
}
|
||||
if !strings.Contains(amzUA, "aws-sdk-js/") {
|
||||
t.Errorf("x-amz-user-agent should contain aws-sdk-js: %s", amzUA)
|
||||
}
|
||||
if !strings.Contains(amzUA, "KiroIDE-") {
|
||||
t.Errorf("x-amz-user-agent should contain KiroIDE: %s", amzUA)
|
||||
}
|
||||
if !strings.Contains(amzUA, machineID) {
|
||||
t.Errorf("x-amz-user-agent should contain machineID: %s", amzUA)
|
||||
}
|
||||
|
||||
// Check User-Agent header
|
||||
ua := req.Header.Get("User-Agent")
|
||||
if ua == "" {
|
||||
t.Error("expected User-Agent header to be set")
|
||||
}
|
||||
if !strings.Contains(ua, "api/codewhispererruntime#") {
|
||||
t.Errorf("User-Agent should contain api/codewhispererruntime: %s", ua)
|
||||
}
|
||||
if !strings.Contains(ua, "m/N,E") {
|
||||
t.Errorf("User-Agent should contain m/N,E: %s", ua)
|
||||
}
|
||||
|
||||
// Check amz-sdk-invocation-id (should be a UUID)
|
||||
invocationID := req.Header.Get("amz-sdk-invocation-id")
|
||||
if invocationID == "" {
|
||||
t.Error("expected amz-sdk-invocation-id header to be set")
|
||||
}
|
||||
if len(invocationID) != 36 {
|
||||
t.Errorf("expected amz-sdk-invocation-id to be UUID (36 chars), got %d", len(invocationID))
|
||||
}
|
||||
|
||||
// Check amz-sdk-request
|
||||
if req.Header.Get("amz-sdk-request") != "attempt=1; max=1" {
|
||||
t.Errorf("unexpected amz-sdk-request header: %s", req.Header.Get("amz-sdk-request"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSDKVersionsAreValid(t *testing.T) {
|
||||
// Verify all OIDC SDK versions match expected format (3.xxx.x)
|
||||
for _, v := range oidcSDKVersions {
|
||||
if !strings.HasPrefix(v, "3.") {
|
||||
t.Errorf("OIDC SDK version should start with 3.: %s", v)
|
||||
}
|
||||
parts := strings.Split(v, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("OIDC SDK version should have 3 parts: %s", v)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range runtimeSDKVersions {
|
||||
parts := strings.Split(v, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("Runtime SDK version should have 3 parts: %s", v)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range streamingSDKVersions {
|
||||
parts := strings.Split(v, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("Streaming SDK version should have 3 parts: %s", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestKiroVersionsAreValid(t *testing.T) {
|
||||
// Verify all Kiro versions match expected format (0.x.xxx)
|
||||
for _, v := range kiroVersions {
|
||||
if !strings.HasPrefix(v, "0.") {
|
||||
t.Errorf("Kiro version should start with 0.: %s", v)
|
||||
}
|
||||
parts := strings.Split(v, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("Kiro version should have 3 parts: %s", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodeVersionsAreValid(t *testing.T) {
|
||||
// Verify all Node versions match expected format (xx.xx.x)
|
||||
for _, v := range nodeVersions {
|
||||
parts := strings.Split(v, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("Node version should have 3 parts: %s", v)
|
||||
}
|
||||
// Should be Node 20.x or 22.x
|
||||
if !strings.HasPrefix(v, "20.") && !strings.HasPrefix(v, "22.") {
|
||||
t.Errorf("Node version should be 20.x or 22.x LTS: %s", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFingerprintManager_SetConfig(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
|
||||
// Without config, should generate random fingerprint
|
||||
fp1 := fm.GetFingerprint("token1")
|
||||
if fp1 == nil {
|
||||
t.Fatal("expected non-nil fingerprint")
|
||||
}
|
||||
|
||||
// Set config with all fields
|
||||
cfg := &FingerprintConfig{
|
||||
OIDCSDKVersion: "3.999.0",
|
||||
RuntimeSDKVersion: "9.9.9",
|
||||
StreamingSDKVersion: "8.8.8",
|
||||
OSType: "darwin",
|
||||
OSVersion: "99.0.0",
|
||||
NodeVersion: "99.99.99",
|
||||
KiroVersion: "9.9.999",
|
||||
KiroHash: "customhash123",
|
||||
}
|
||||
fm.SetConfig(cfg)
|
||||
|
||||
// After setting config, should use config values
|
||||
fp2 := fm.GetFingerprint("token2")
|
||||
if fp2.OIDCSDKVersion != "3.999.0" {
|
||||
t.Errorf("expected OIDCSDKVersion '3.999.0', got '%s'", fp2.OIDCSDKVersion)
|
||||
}
|
||||
if fp2.RuntimeSDKVersion != "9.9.9" {
|
||||
t.Errorf("expected RuntimeSDKVersion '9.9.9', got '%s'", fp2.RuntimeSDKVersion)
|
||||
}
|
||||
if fp2.StreamingSDKVersion != "8.8.8" {
|
||||
t.Errorf("expected StreamingSDKVersion '8.8.8', got '%s'", fp2.StreamingSDKVersion)
|
||||
}
|
||||
if fp2.OSType != "darwin" {
|
||||
t.Errorf("expected OSType 'darwin', got '%s'", fp2.OSType)
|
||||
}
|
||||
if fp2.OSVersion != "99.0.0" {
|
||||
t.Errorf("expected OSVersion '99.0.0', got '%s'", fp2.OSVersion)
|
||||
}
|
||||
if fp2.NodeVersion != "99.99.99" {
|
||||
t.Errorf("expected NodeVersion '99.99.99', got '%s'", fp2.NodeVersion)
|
||||
}
|
||||
if fp2.KiroVersion != "9.9.999" {
|
||||
t.Errorf("expected KiroVersion '9.9.999', got '%s'", fp2.KiroVersion)
|
||||
}
|
||||
if fp2.KiroHash != "customhash123" {
|
||||
t.Errorf("expected KiroHash 'customhash123', got '%s'", fp2.KiroHash)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFingerprintManager_SetConfig_PartialFields(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
|
||||
// Set config with only some fields
|
||||
cfg := &FingerprintConfig{
|
||||
KiroVersion: "1.2.345",
|
||||
KiroHash: "myhash",
|
||||
// Other fields empty - should use random
|
||||
}
|
||||
fm.SetConfig(cfg)
|
||||
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
// Configured fields should use config values
|
||||
if fp.KiroVersion != "1.2.345" {
|
||||
t.Errorf("expected KiroVersion '1.2.345', got '%s'", fp.KiroVersion)
|
||||
}
|
||||
if fp.KiroHash != "myhash" {
|
||||
t.Errorf("expected KiroHash 'myhash', got '%s'", fp.KiroHash)
|
||||
}
|
||||
|
||||
// Empty fields should be randomly selected (non-empty)
|
||||
if fp.OIDCSDKVersion == "" {
|
||||
t.Error("expected non-empty OIDCSDKVersion")
|
||||
}
|
||||
if fp.OSType == "" {
|
||||
t.Error("expected non-empty OSType")
|
||||
}
|
||||
if fp.NodeVersion == "" {
|
||||
t.Error("expected non-empty NodeVersion")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFingerprintManager_SetConfig_ClearsCache(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
|
||||
// Get fingerprint before config
|
||||
fp1 := fm.GetFingerprint("token1")
|
||||
originalHash := fp1.KiroHash
|
||||
|
||||
// Set config
|
||||
cfg := &FingerprintConfig{
|
||||
KiroHash: "newcustomhash",
|
||||
}
|
||||
fm.SetConfig(cfg)
|
||||
|
||||
// Same token should now return different fingerprint (cache cleared)
|
||||
fp2 := fm.GetFingerprint("token1")
|
||||
if fp2.KiroHash == originalHash {
|
||||
t.Error("expected cache to be cleared after SetConfig")
|
||||
}
|
||||
if fp2.KiroHash != "newcustomhash" {
|
||||
t.Errorf("expected KiroHash 'newcustomhash', got '%s'", fp2.KiroHash)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateAccountKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
seed string
|
||||
check func(t *testing.T, result string)
|
||||
}{
|
||||
{
|
||||
name: "Empty seed",
|
||||
seed: "",
|
||||
check: func(t *testing.T, result string) {
|
||||
if result == "" {
|
||||
t.Error("expected non-empty result for empty seed")
|
||||
}
|
||||
if len(result) != 16 {
|
||||
t.Errorf("expected 16 char hex string, got %d chars", len(result))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Simple seed",
|
||||
seed: "test-client-id",
|
||||
check: func(t *testing.T, result string) {
|
||||
if len(result) != 16 {
|
||||
t.Errorf("expected 16 char hex string, got %d chars", len(result))
|
||||
}
|
||||
// Verify it's valid hex
|
||||
for _, c := range result {
|
||||
if (c < '0' || c > '9') && (c < 'a' || c > 'f') {
|
||||
t.Errorf("invalid hex character: %c", c)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Same seed produces same result",
|
||||
seed: "deterministic-seed",
|
||||
check: func(t *testing.T, result string) {
|
||||
result2 := GenerateAccountKey("deterministic-seed")
|
||||
if result != result2 {
|
||||
t.Errorf("same seed should produce same result: %s vs %s", result, result2)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Different seeds produce different results",
|
||||
seed: "seed-one",
|
||||
check: func(t *testing.T, result string) {
|
||||
result2 := GenerateAccountKey("seed-two")
|
||||
if result == result2 {
|
||||
t.Errorf("different seeds should produce different results: %s vs %s", result, result2)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GenerateAccountKey(tt.seed)
|
||||
tt.check(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
clientID string
|
||||
refreshToken string
|
||||
check func(t *testing.T, result string)
|
||||
}{
|
||||
{
|
||||
name: "Priority 1: clientID when both provided",
|
||||
clientID: "client-id-123",
|
||||
refreshToken: "refresh-token-456",
|
||||
check: func(t *testing.T, result string) {
|
||||
expected := GenerateAccountKey("client-id-123")
|
||||
if result != expected {
|
||||
t.Errorf("expected clientID-based key %s, got %s", expected, result)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Priority 2: refreshToken when clientID is empty",
|
||||
clientID: "",
|
||||
refreshToken: "refresh-token-789",
|
||||
check: func(t *testing.T, result string) {
|
||||
expected := GenerateAccountKey("refresh-token-789")
|
||||
if result != expected {
|
||||
t.Errorf("expected refreshToken-based key %s, got %s", expected, result)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Priority 3: random when both empty",
|
||||
clientID: "",
|
||||
refreshToken: "",
|
||||
check: func(t *testing.T, result string) {
|
||||
if len(result) != 16 {
|
||||
t.Errorf("expected 16 char key, got %d chars", len(result))
|
||||
}
|
||||
// Should be different each time (random UUID)
|
||||
result2 := GetAccountKey("", "")
|
||||
if result == result2 {
|
||||
t.Log("warning: random keys are the same (possible but unlikely)")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "clientID only",
|
||||
clientID: "solo-client-id",
|
||||
refreshToken: "",
|
||||
check: func(t *testing.T, result string) {
|
||||
expected := GenerateAccountKey("solo-client-id")
|
||||
if result != expected {
|
||||
t.Errorf("expected clientID-based key %s, got %s", expected, result)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "refreshToken only",
|
||||
clientID: "",
|
||||
refreshToken: "solo-refresh-token",
|
||||
check: func(t *testing.T, result string) {
|
||||
expected := GenerateAccountKey("solo-refresh-token")
|
||||
if result != expected {
|
||||
t.Errorf("expected refreshToken-based key %s, got %s", expected, result)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetAccountKey(tt.clientID, tt.refreshToken)
|
||||
tt.check(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountKey_Deterministic(t *testing.T) {
|
||||
// Verify that GetAccountKey produces deterministic results for same inputs
|
||||
clientID := "test-client-id-abc"
|
||||
refreshToken := "test-refresh-token-xyz"
|
||||
|
||||
// Call multiple times with same inputs
|
||||
results := make([]string, 10)
|
||||
for i := range 10 {
|
||||
results[i] = GetAccountKey(clientID, refreshToken)
|
||||
}
|
||||
|
||||
// All results should be identical
|
||||
for i := 1; i < 10; i++ {
|
||||
if results[i] != results[0] {
|
||||
t.Errorf("GetAccountKey should be deterministic: got %s and %s", results[0], results[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFingerprintDeterministic(t *testing.T) {
|
||||
// Verify that fingerprints are deterministic based on accountKey
|
||||
fm := NewFingerprintManager()
|
||||
|
||||
accountKey := GenerateAccountKey("test-client-id")
|
||||
|
||||
// Get fingerprint multiple times
|
||||
fp1 := fm.GetFingerprint(accountKey)
|
||||
fp2 := fm.GetFingerprint(accountKey)
|
||||
|
||||
// Should be the same pointer (cached)
|
||||
if fp1 != fp2 {
|
||||
t.Error("expected same fingerprint pointer for same key")
|
||||
}
|
||||
|
||||
// Create new manager and verify same values
|
||||
fm2 := NewFingerprintManager()
|
||||
fp3 := fm2.GetFingerprint(accountKey)
|
||||
|
||||
// Values should be identical (deterministic generation)
|
||||
if fp1.KiroHash != fp3.KiroHash {
|
||||
t.Errorf("KiroHash should be deterministic: %s vs %s", fp1.KiroHash, fp3.KiroHash)
|
||||
}
|
||||
if fp1.OSType != fp3.OSType {
|
||||
t.Errorf("OSType should be deterministic: %s vs %s", fp1.OSType, fp3.OSType)
|
||||
}
|
||||
if fp1.OSVersion != fp3.OSVersion {
|
||||
t.Errorf("OSVersion should be deterministic: %s vs %s", fp1.OSVersion, fp3.OSVersion)
|
||||
}
|
||||
if fp1.KiroVersion != fp3.KiroVersion {
|
||||
t.Errorf("KiroVersion should be deterministic: %s vs %s", fp1.KiroVersion, fp3.KiroVersion)
|
||||
}
|
||||
if fp1.NodeVersion != fp3.NodeVersion {
|
||||
t.Errorf("NodeVersion should be deterministic: %s vs %s", fp1.NodeVersion, fp3.NodeVersion)
|
||||
}
|
||||
}
|
||||
174
internal/auth/kiro/jitter.go
Normal file
174
internal/auth/kiro/jitter.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Jitter configuration constants
|
||||
const (
|
||||
// JitterPercent is the default percentage of jitter to apply (±30%)
|
||||
JitterPercent = 0.30
|
||||
|
||||
// Human-like delay ranges
|
||||
ShortDelayMin = 50 * time.Millisecond // Minimum for rapid consecutive operations
|
||||
ShortDelayMax = 200 * time.Millisecond // Maximum for rapid consecutive operations
|
||||
NormalDelayMin = 1 * time.Second // Minimum for normal thinking time
|
||||
NormalDelayMax = 3 * time.Second // Maximum for normal thinking time
|
||||
LongDelayMin = 5 * time.Second // Minimum for reading/resting
|
||||
LongDelayMax = 10 * time.Second // Maximum for reading/resting
|
||||
|
||||
// Probability thresholds for human-like behavior
|
||||
ShortDelayProbability = 0.20 // 20% chance of short delay (consecutive ops)
|
||||
LongDelayProbability = 0.05 // 5% chance of long delay (reading/resting)
|
||||
NormalDelayProbability = 0.75 // 75% chance of normal delay (thinking)
|
||||
)
|
||||
|
||||
var (
|
||||
jitterRand *rand.Rand
|
||||
jitterRandOnce sync.Once
|
||||
jitterMu sync.Mutex
|
||||
lastRequestTime time.Time
|
||||
)
|
||||
|
||||
// initJitterRand initializes the random number generator for jitter calculations.
|
||||
// Uses a time-based seed for unpredictable but reproducible randomness.
|
||||
func initJitterRand() {
|
||||
jitterRandOnce.Do(func() {
|
||||
jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
})
|
||||
}
|
||||
|
||||
// RandomDelay generates a random delay between min and max duration.
|
||||
// Thread-safe implementation using mutex protection.
|
||||
func RandomDelay(min, max time.Duration) time.Duration {
|
||||
initJitterRand()
|
||||
jitterMu.Lock()
|
||||
defer jitterMu.Unlock()
|
||||
|
||||
if min >= max {
|
||||
return min
|
||||
}
|
||||
|
||||
rangeMs := max.Milliseconds() - min.Milliseconds()
|
||||
randomMs := jitterRand.Int63n(rangeMs)
|
||||
return min + time.Duration(randomMs)*time.Millisecond
|
||||
}
|
||||
|
||||
// JitterDelay adds jitter to a base delay.
|
||||
// Applies ±jitterPercent variation to the base delay.
|
||||
// For example, JitterDelay(1*time.Second, 0.30) returns a value between 700ms and 1300ms.
|
||||
func JitterDelay(baseDelay time.Duration, jitterPercent float64) time.Duration {
|
||||
initJitterRand()
|
||||
jitterMu.Lock()
|
||||
defer jitterMu.Unlock()
|
||||
|
||||
if jitterPercent <= 0 || jitterPercent > 1 {
|
||||
jitterPercent = JitterPercent
|
||||
}
|
||||
|
||||
// Calculate jitter range: base * jitterPercent
|
||||
jitterRange := float64(baseDelay) * jitterPercent
|
||||
|
||||
// Generate random value in range [-jitterRange, +jitterRange]
|
||||
jitter := (jitterRand.Float64()*2 - 1) * jitterRange
|
||||
|
||||
result := time.Duration(float64(baseDelay) + jitter)
|
||||
if result < 0 {
|
||||
return 0
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// JitterDelayDefault applies the default ±30% jitter to a base delay.
|
||||
func JitterDelayDefault(baseDelay time.Duration) time.Duration {
|
||||
return JitterDelay(baseDelay, JitterPercent)
|
||||
}
|
||||
|
||||
// HumanLikeDelay generates a delay that mimics human behavior patterns.
|
||||
// The delay is selected based on probability distribution:
|
||||
// - 20% chance: Short delay (50-200ms) - simulates consecutive rapid operations
|
||||
// - 75% chance: Normal delay (1-3s) - simulates thinking/reading time
|
||||
// - 5% chance: Long delay (5-10s) - simulates breaks/reading longer content
|
||||
//
|
||||
// Returns the delay duration (caller should call time.Sleep with this value).
|
||||
func HumanLikeDelay() time.Duration {
|
||||
initJitterRand()
|
||||
jitterMu.Lock()
|
||||
defer jitterMu.Unlock()
|
||||
|
||||
// Track time since last request for adaptive behavior
|
||||
now := time.Now()
|
||||
timeSinceLastRequest := now.Sub(lastRequestTime)
|
||||
lastRequestTime = now
|
||||
|
||||
// If requests are very close together, use short delay
|
||||
if timeSinceLastRequest < 500*time.Millisecond && timeSinceLastRequest > 0 {
|
||||
rangeMs := ShortDelayMax.Milliseconds() - ShortDelayMin.Milliseconds()
|
||||
randomMs := jitterRand.Int63n(rangeMs)
|
||||
return ShortDelayMin + time.Duration(randomMs)*time.Millisecond
|
||||
}
|
||||
|
||||
// Otherwise, use probability-based selection
|
||||
roll := jitterRand.Float64()
|
||||
|
||||
var min, max time.Duration
|
||||
switch {
|
||||
case roll < ShortDelayProbability:
|
||||
// Short delay - consecutive operations
|
||||
min, max = ShortDelayMin, ShortDelayMax
|
||||
case roll < ShortDelayProbability+LongDelayProbability:
|
||||
// Long delay - reading/resting
|
||||
min, max = LongDelayMin, LongDelayMax
|
||||
default:
|
||||
// Normal delay - thinking time
|
||||
min, max = NormalDelayMin, NormalDelayMax
|
||||
}
|
||||
|
||||
rangeMs := max.Milliseconds() - min.Milliseconds()
|
||||
randomMs := jitterRand.Int63n(rangeMs)
|
||||
return min + time.Duration(randomMs)*time.Millisecond
|
||||
}
|
||||
|
||||
// ApplyHumanLikeDelay applies human-like delay by sleeping.
|
||||
// This is a convenience function that combines HumanLikeDelay with time.Sleep.
|
||||
func ApplyHumanLikeDelay() {
|
||||
delay := HumanLikeDelay()
|
||||
if delay > 0 {
|
||||
time.Sleep(delay)
|
||||
}
|
||||
}
|
||||
|
||||
// ExponentialBackoffWithJitter calculates retry delay using exponential backoff with jitter.
|
||||
// Formula: min(baseDelay * 2^attempt + jitter, maxDelay)
|
||||
// This helps prevent thundering herd problem when multiple clients retry simultaneously.
|
||||
func ExponentialBackoffWithJitter(attempt int, baseDelay, maxDelay time.Duration) time.Duration {
|
||||
if attempt < 0 {
|
||||
attempt = 0
|
||||
}
|
||||
|
||||
// Calculate exponential backoff: baseDelay * 2^attempt
|
||||
backoff := baseDelay * time.Duration(1<<uint(attempt))
|
||||
if backoff > maxDelay {
|
||||
backoff = maxDelay
|
||||
}
|
||||
|
||||
// Add ±30% jitter
|
||||
return JitterDelay(backoff, JitterPercent)
|
||||
}
|
||||
|
||||
// ShouldSkipDelay determines if delay should be skipped based on context.
|
||||
// Returns true for streaming responses, WebSocket connections, etc.
|
||||
// This function can be extended to check additional skip conditions.
|
||||
func ShouldSkipDelay(isStreaming bool) bool {
|
||||
return isStreaming
|
||||
}
|
||||
|
||||
// ResetLastRequestTime resets the last request time tracker.
|
||||
// Useful for testing or when starting a new session.
|
||||
func ResetLastRequestTime() {
|
||||
jitterMu.Lock()
|
||||
defer jitterMu.Unlock()
|
||||
lastRequestTime = time.Time{}
|
||||
}
|
||||
187
internal/auth/kiro/metrics.go
Normal file
187
internal/auth/kiro/metrics.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TokenMetrics holds performance metrics for a single token.
|
||||
type TokenMetrics struct {
|
||||
SuccessRate float64 // Success rate (0.0 - 1.0)
|
||||
AvgLatency float64 // Average latency in milliseconds
|
||||
QuotaRemaining float64 // Remaining quota (0.0 - 1.0)
|
||||
LastUsed time.Time // Last usage timestamp
|
||||
FailCount int // Consecutive failure count
|
||||
TotalRequests int // Total request count
|
||||
successCount int // Internal: successful request count
|
||||
totalLatency float64 // Internal: cumulative latency
|
||||
}
|
||||
|
||||
// TokenScorer manages token metrics and scoring.
|
||||
type TokenScorer struct {
|
||||
mu sync.RWMutex
|
||||
metrics map[string]*TokenMetrics
|
||||
|
||||
// Scoring weights
|
||||
successRateWeight float64
|
||||
quotaWeight float64
|
||||
latencyWeight float64
|
||||
lastUsedWeight float64
|
||||
failPenaltyMultiplier float64
|
||||
}
|
||||
|
||||
// NewTokenScorer creates a new TokenScorer with default weights.
|
||||
func NewTokenScorer() *TokenScorer {
|
||||
return &TokenScorer{
|
||||
metrics: make(map[string]*TokenMetrics),
|
||||
successRateWeight: 0.4,
|
||||
quotaWeight: 0.25,
|
||||
latencyWeight: 0.2,
|
||||
lastUsedWeight: 0.15,
|
||||
failPenaltyMultiplier: 0.1,
|
||||
}
|
||||
}
|
||||
|
||||
// getOrCreateMetrics returns existing metrics or creates new ones.
|
||||
func (s *TokenScorer) getOrCreateMetrics(tokenKey string) *TokenMetrics {
|
||||
if m, ok := s.metrics[tokenKey]; ok {
|
||||
return m
|
||||
}
|
||||
m := &TokenMetrics{
|
||||
SuccessRate: 1.0,
|
||||
QuotaRemaining: 1.0,
|
||||
}
|
||||
s.metrics[tokenKey] = m
|
||||
return m
|
||||
}
|
||||
|
||||
// RecordRequest records the result of a request for a token.
|
||||
func (s *TokenScorer) RecordRequest(tokenKey string, success bool, latency time.Duration) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
m := s.getOrCreateMetrics(tokenKey)
|
||||
m.TotalRequests++
|
||||
m.LastUsed = time.Now()
|
||||
m.totalLatency += float64(latency.Milliseconds())
|
||||
|
||||
if success {
|
||||
m.successCount++
|
||||
m.FailCount = 0
|
||||
} else {
|
||||
m.FailCount++
|
||||
}
|
||||
|
||||
// Update derived metrics
|
||||
if m.TotalRequests > 0 {
|
||||
m.SuccessRate = float64(m.successCount) / float64(m.TotalRequests)
|
||||
m.AvgLatency = m.totalLatency / float64(m.TotalRequests)
|
||||
}
|
||||
}
|
||||
|
||||
// SetQuotaRemaining updates the remaining quota for a token.
|
||||
func (s *TokenScorer) SetQuotaRemaining(tokenKey string, quota float64) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
m := s.getOrCreateMetrics(tokenKey)
|
||||
m.QuotaRemaining = quota
|
||||
}
|
||||
|
||||
// GetMetrics returns a copy of the metrics for a token.
|
||||
func (s *TokenScorer) GetMetrics(tokenKey string) *TokenMetrics {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if m, ok := s.metrics[tokenKey]; ok {
|
||||
copy := *m
|
||||
return ©
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CalculateScore computes the score for a token (higher is better).
|
||||
func (s *TokenScorer) CalculateScore(tokenKey string) float64 {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
m, ok := s.metrics[tokenKey]
|
||||
if !ok {
|
||||
return 1.0 // New tokens get a high initial score
|
||||
}
|
||||
|
||||
// Success rate component (0-1)
|
||||
successScore := m.SuccessRate
|
||||
|
||||
// Quota component (0-1)
|
||||
quotaScore := m.QuotaRemaining
|
||||
|
||||
// Latency component (normalized, lower is better)
|
||||
// Using exponential decay: score = e^(-latency/1000)
|
||||
// 1000ms latency -> ~0.37 score, 100ms -> ~0.90 score
|
||||
latencyScore := math.Exp(-m.AvgLatency / 1000.0)
|
||||
if m.TotalRequests == 0 {
|
||||
latencyScore = 1.0
|
||||
}
|
||||
|
||||
// Last used component (prefer tokens not recently used)
|
||||
// Score increases as time since last use increases
|
||||
timeSinceUse := time.Since(m.LastUsed).Seconds()
|
||||
// Normalize: 60 seconds -> ~0.63 score, 0 seconds -> 0 score
|
||||
lastUsedScore := 1.0 - math.Exp(-timeSinceUse/60.0)
|
||||
if m.LastUsed.IsZero() {
|
||||
lastUsedScore = 1.0
|
||||
}
|
||||
|
||||
// Calculate weighted score
|
||||
score := s.successRateWeight*successScore +
|
||||
s.quotaWeight*quotaScore +
|
||||
s.latencyWeight*latencyScore +
|
||||
s.lastUsedWeight*lastUsedScore
|
||||
|
||||
// Apply consecutive failure penalty
|
||||
if m.FailCount > 0 {
|
||||
penalty := s.failPenaltyMultiplier * float64(m.FailCount)
|
||||
score = score * math.Max(0, 1.0-penalty)
|
||||
}
|
||||
|
||||
return score
|
||||
}
|
||||
|
||||
// SelectBestToken selects the token with the highest score.
|
||||
func (s *TokenScorer) SelectBestToken(tokens []string) string {
|
||||
if len(tokens) == 0 {
|
||||
return ""
|
||||
}
|
||||
if len(tokens) == 1 {
|
||||
return tokens[0]
|
||||
}
|
||||
|
||||
bestToken := tokens[0]
|
||||
bestScore := s.CalculateScore(tokens[0])
|
||||
|
||||
for _, token := range tokens[1:] {
|
||||
score := s.CalculateScore(token)
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
bestToken = token
|
||||
}
|
||||
}
|
||||
|
||||
return bestToken
|
||||
}
|
||||
|
||||
// ResetMetrics clears all metrics for a token.
|
||||
func (s *TokenScorer) ResetMetrics(tokenKey string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.metrics, tokenKey)
|
||||
}
|
||||
|
||||
// ResetAllMetrics clears all stored metrics.
|
||||
func (s *TokenScorer) ResetAllMetrics() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.metrics = make(map[string]*TokenMetrics)
|
||||
}
|
||||
301
internal/auth/kiro/metrics_test.go
Normal file
301
internal/auth/kiro/metrics_test.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewTokenScorer(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
if s == nil {
|
||||
t.Fatal("expected non-nil TokenScorer")
|
||||
}
|
||||
if s.metrics == nil {
|
||||
t.Error("expected non-nil metrics map")
|
||||
}
|
||||
if s.successRateWeight != 0.4 {
|
||||
t.Errorf("expected successRateWeight 0.4, got %f", s.successRateWeight)
|
||||
}
|
||||
if s.quotaWeight != 0.25 {
|
||||
t.Errorf("expected quotaWeight 0.25, got %f", s.quotaWeight)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordRequest_Success(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m == nil {
|
||||
t.Fatal("expected non-nil metrics")
|
||||
}
|
||||
if m.TotalRequests != 1 {
|
||||
t.Errorf("expected TotalRequests 1, got %d", m.TotalRequests)
|
||||
}
|
||||
if m.SuccessRate != 1.0 {
|
||||
t.Errorf("expected SuccessRate 1.0, got %f", m.SuccessRate)
|
||||
}
|
||||
if m.FailCount != 0 {
|
||||
t.Errorf("expected FailCount 0, got %d", m.FailCount)
|
||||
}
|
||||
if m.AvgLatency != 100 {
|
||||
t.Errorf("expected AvgLatency 100, got %f", m.AvgLatency)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordRequest_Failure(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", false, 200*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.SuccessRate != 0.0 {
|
||||
t.Errorf("expected SuccessRate 0.0, got %f", m.SuccessRate)
|
||||
}
|
||||
if m.FailCount != 1 {
|
||||
t.Errorf("expected FailCount 1, got %d", m.FailCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordRequest_MixedResults(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.TotalRequests != 4 {
|
||||
t.Errorf("expected TotalRequests 4, got %d", m.TotalRequests)
|
||||
}
|
||||
if m.SuccessRate != 0.75 {
|
||||
t.Errorf("expected SuccessRate 0.75, got %f", m.SuccessRate)
|
||||
}
|
||||
if m.FailCount != 0 {
|
||||
t.Errorf("expected FailCount 0 (reset on success), got %d", m.FailCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordRequest_ConsecutiveFailures(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.FailCount != 3 {
|
||||
t.Errorf("expected FailCount 3, got %d", m.FailCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetQuotaRemaining(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.SetQuotaRemaining("token1", 0.5)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.QuotaRemaining != 0.5 {
|
||||
t.Errorf("expected QuotaRemaining 0.5, got %f", m.QuotaRemaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMetrics_NonExistent(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
m := s.GetMetrics("nonexistent")
|
||||
if m != nil {
|
||||
t.Error("expected nil metrics for non-existent token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMetrics_ReturnsCopy(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m1 := s.GetMetrics("token1")
|
||||
m1.TotalRequests = 999
|
||||
|
||||
m2 := s.GetMetrics("token1")
|
||||
if m2.TotalRequests == 999 {
|
||||
t.Error("GetMetrics should return a copy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateScore_NewToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
score := s.CalculateScore("newtoken")
|
||||
if score != 1.0 {
|
||||
t.Errorf("expected score 1.0 for new token, got %f", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateScore_PerfectToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 50*time.Millisecond)
|
||||
s.SetQuotaRemaining("token1", 1.0)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
score := s.CalculateScore("token1")
|
||||
if score < 0.5 || score > 1.0 {
|
||||
t.Errorf("expected high score for perfect token, got %f", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateScore_FailedToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
for i := 0; i < 5; i++ {
|
||||
s.RecordRequest("token1", false, 1000*time.Millisecond)
|
||||
}
|
||||
s.SetQuotaRemaining("token1", 0.1)
|
||||
|
||||
score := s.CalculateScore("token1")
|
||||
if score > 0.5 {
|
||||
t.Errorf("expected low score for failed token, got %f", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateScore_FailPenalty(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
scoreNoFail := s.CalculateScore("token1")
|
||||
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
scoreWithFail := s.CalculateScore("token1")
|
||||
|
||||
if scoreWithFail >= scoreNoFail {
|
||||
t.Errorf("expected lower score with consecutive failures: noFail=%f, withFail=%f", scoreNoFail, scoreWithFail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectBestToken_Empty(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
best := s.SelectBestToken([]string{})
|
||||
if best != "" {
|
||||
t.Errorf("expected empty string for empty tokens, got %s", best)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectBestToken_SingleToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
best := s.SelectBestToken([]string{"token1"})
|
||||
if best != "token1" {
|
||||
t.Errorf("expected token1, got %s", best)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectBestToken_MultipleTokens(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
|
||||
s.RecordRequest("bad", false, 1000*time.Millisecond)
|
||||
s.RecordRequest("bad", false, 1000*time.Millisecond)
|
||||
s.SetQuotaRemaining("bad", 0.1)
|
||||
|
||||
s.RecordRequest("good", true, 50*time.Millisecond)
|
||||
s.SetQuotaRemaining("good", 0.9)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
best := s.SelectBestToken([]string{"bad", "good"})
|
||||
if best != "good" {
|
||||
t.Errorf("expected good token to be selected, got %s", best)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetMetrics(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.ResetMetrics("token1")
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m != nil {
|
||||
t.Error("expected nil metrics after reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetAllMetrics(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token2", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token3", true, 100*time.Millisecond)
|
||||
|
||||
s.ResetAllMetrics()
|
||||
|
||||
if s.GetMetrics("token1") != nil {
|
||||
t.Error("expected nil metrics for token1 after reset all")
|
||||
}
|
||||
if s.GetMetrics("token2") != nil {
|
||||
t.Error("expected nil metrics for token2 after reset all")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenScorer_ConcurrentAccess(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
const numGoroutines = 50
|
||||
const numOperations = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
tokenKey := "token" + string(rune('a'+id%10))
|
||||
for j := 0; j < numOperations; j++ {
|
||||
switch j % 6 {
|
||||
case 0:
|
||||
s.RecordRequest(tokenKey, j%2 == 0, time.Duration(j)*time.Millisecond)
|
||||
case 1:
|
||||
s.SetQuotaRemaining(tokenKey, float64(j%100)/100)
|
||||
case 2:
|
||||
s.GetMetrics(tokenKey)
|
||||
case 3:
|
||||
s.CalculateScore(tokenKey)
|
||||
case 4:
|
||||
s.SelectBestToken([]string{tokenKey, "token_x", "token_y"})
|
||||
case 5:
|
||||
if j%20 == 0 {
|
||||
s.ResetMetrics(tokenKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAvgLatencyCalculation(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", true, 200*time.Millisecond)
|
||||
s.RecordRequest("token1", true, 300*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.AvgLatency != 200 {
|
||||
t.Errorf("expected AvgLatency 200, got %f", m.AvgLatency)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLastUsedUpdated(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
before := time.Now()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.LastUsed.Before(before) {
|
||||
t.Error("expected LastUsed to be after test start time")
|
||||
}
|
||||
if m.LastUsed.After(time.Now()) {
|
||||
t.Error("expected LastUsed to be before or equal to now")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultQuotaForNewToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.QuotaRemaining != 1.0 {
|
||||
t.Errorf("expected default QuotaRemaining 1.0, got %f", m.QuotaRemaining)
|
||||
}
|
||||
}
|
||||
319
internal/auth/kiro/oauth.go
Normal file
319
internal/auth/kiro/oauth.go
Normal file
@@ -0,0 +1,319 @@
|
||||
// Package kiro provides OAuth2 authentication for Kiro using native Google login.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// Kiro auth endpoint
|
||||
kiroAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
|
||||
|
||||
// Default callback port
|
||||
defaultCallbackPort = 9876
|
||||
|
||||
// Auth timeout
|
||||
authTimeout = 10 * time.Minute
|
||||
)
|
||||
|
||||
// KiroTokenResponse represents the response from Kiro token endpoint.
|
||||
type KiroTokenResponse struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ProfileArn string `json:"profileArn"`
|
||||
ExpiresIn int `json:"expiresIn"`
|
||||
}
|
||||
|
||||
// KiroOAuth handles the OAuth flow for Kiro authentication.
|
||||
type KiroOAuth struct {
|
||||
httpClient *http.Client
|
||||
cfg *config.Config
|
||||
machineID string
|
||||
kiroVersion string
|
||||
}
|
||||
|
||||
// NewKiroOAuth creates a new Kiro OAuth handler.
|
||||
func NewKiroOAuth(cfg *config.Config) *KiroOAuth {
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
if cfg != nil {
|
||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||
}
|
||||
fp := GlobalFingerprintManager().GetFingerprint("login")
|
||||
return &KiroOAuth{
|
||||
httpClient: client,
|
||||
cfg: cfg,
|
||||
machineID: fp.KiroHash,
|
||||
kiroVersion: fp.KiroVersion,
|
||||
}
|
||||
}
|
||||
|
||||
// generateCodeVerifier generates a random code verifier for PKCE.
|
||||
func generateCodeVerifier() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// generateCodeChallenge generates the code challenge from verifier.
|
||||
func generateCodeChallenge(verifier string) string {
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
return base64.RawURLEncoding.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
// generateState generates a random state parameter.
|
||||
func generateState() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// AuthResult contains the authorization code and state from callback.
|
||||
type AuthResult struct {
|
||||
Code string
|
||||
State string
|
||||
Error string
|
||||
}
|
||||
|
||||
// startCallbackServer starts a local HTTP server to receive the OAuth callback.
|
||||
func (o *KiroOAuth) startCallbackServer(ctx context.Context, expectedState string) (string, <-chan AuthResult, error) {
|
||||
// Try to find an available port - use localhost like Kiro does
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", defaultCallbackPort))
|
||||
if err != nil {
|
||||
// Try with dynamic port (RFC 8252 allows dynamic ports for native apps)
|
||||
log.Warnf("kiro oauth: default port %d is busy, falling back to dynamic port", defaultCallbackPort)
|
||||
listener, err = net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to start callback server: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
// Use http scheme for local callback server
|
||||
redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", port)
|
||||
resultChan := make(chan AuthResult, 1)
|
||||
|
||||
server := &http.Server{
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
errParam := r.URL.Query().Get("error")
|
||||
|
||||
if errParam != "" {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(w, `<html><body><h1>Login Failed</h1><p>%s</p><p>You can close this window.</p></body></html>`, html.EscapeString(errParam))
|
||||
resultChan <- AuthResult{Error: errParam}
|
||||
return
|
||||
}
|
||||
|
||||
if state != expectedState {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(w, `<html><body><h1>Login Failed</h1><p>Invalid state parameter</p><p>You can close this window.</p></body></html>`)
|
||||
resultChan <- AuthResult{Error: "state mismatch"}
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
fmt.Fprint(w, `<html><body><h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p></body></html>`)
|
||||
resultChan <- AuthResult{Code: code, State: state}
|
||||
})
|
||||
|
||||
server.Handler = mux
|
||||
|
||||
go func() {
|
||||
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||
log.Debugf("callback server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(authTimeout):
|
||||
case <-resultChan:
|
||||
}
|
||||
_ = server.Shutdown(context.Background())
|
||||
}()
|
||||
|
||||
return redirectURI, resultChan, nil
|
||||
}
|
||||
|
||||
// LoginWithBuilderID performs OAuth login with AWS Builder ID using device code flow.
|
||||
func (o *KiroOAuth) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, error) {
|
||||
ssoClient := NewSSOOIDCClient(o.cfg)
|
||||
return ssoClient.LoginWithBuilderID(ctx)
|
||||
}
|
||||
|
||||
// LoginWithBuilderIDAuthCode performs OAuth login with AWS Builder ID using authorization code flow.
|
||||
// This provides a better UX than device code flow as it uses automatic browser callback.
|
||||
func (o *KiroOAuth) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTokenData, error) {
|
||||
ssoClient := NewSSOOIDCClient(o.cfg)
|
||||
return ssoClient.LoginWithBuilderIDAuthCode(ctx)
|
||||
}
|
||||
|
||||
// exchangeCodeForToken exchanges the authorization code for tokens.
|
||||
func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier, redirectURI string) (*KiroTokenData, error) {
|
||||
payload := map[string]string{
|
||||
"code": code,
|
||||
"code_verifier": codeVerifier,
|
||||
"redirect_uri": redirectURI,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
tokenURL := kiroAuthEndpoint + "/oauth/token"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", o.kiroVersion, o.machineID))
|
||||
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("token exchange failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("token exchange failed (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
var tokenResp KiroTokenResponse
|
||||
if err := json.Unmarshal(respBody, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||
}
|
||||
|
||||
// Validate ExpiresIn - use default 1 hour if invalid
|
||||
expiresIn := tokenResp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
return &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: tokenResp.ProfileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: "", // Caller should preserve original provider
|
||||
Region: "us-east-1",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an expired access token.
|
||||
// Uses KiroIDE-style User-Agent to match official Kiro IDE behavior.
|
||||
func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) {
|
||||
return o.RefreshTokenWithFingerprint(ctx, refreshToken, "")
|
||||
}
|
||||
|
||||
// RefreshTokenWithFingerprint refreshes an expired access token with a specific fingerprint.
|
||||
// tokenKey is used to generate a consistent fingerprint for the token.
|
||||
func (o *KiroOAuth) RefreshTokenWithFingerprint(ctx context.Context, refreshToken, tokenKey string) (*KiroTokenData, error) {
|
||||
payload := map[string]string{
|
||||
"refreshToken": refreshToken,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
refreshURL := kiroAuthEndpoint + "/refreshToken"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", o.kiroVersion, o.machineID))
|
||||
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var tokenResp KiroTokenResponse
|
||||
if err := json.Unmarshal(respBody, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||
}
|
||||
|
||||
// Validate ExpiresIn - use default 1 hour if invalid
|
||||
expiresIn := tokenResp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
return &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: tokenResp.ProfileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: "", // Caller should preserve original provider
|
||||
Region: "us-east-1",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LoginWithGoogle performs OAuth login with Google using Kiro's social auth.
|
||||
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||
func (o *KiroOAuth) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) {
|
||||
socialClient := NewSocialAuthClient(o.cfg)
|
||||
return socialClient.LoginWithGoogle(ctx)
|
||||
}
|
||||
|
||||
// LoginWithGitHub performs OAuth login with GitHub using Kiro's social auth.
|
||||
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||
func (o *KiroOAuth) LoginWithGitHub(ctx context.Context) (*KiroTokenData, error) {
|
||||
socialClient := NewSocialAuthClient(o.cfg)
|
||||
return socialClient.LoginWithGitHub(ctx)
|
||||
}
|
||||
975
internal/auth/kiro/oauth_web.go
Normal file
975
internal/auth/kiro/oauth_web.go
Normal file
@@ -0,0 +1,975 @@
|
||||
// Package kiro provides OAuth Web authentication for Kiro.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultSessionExpiry = 10 * time.Minute
|
||||
pollIntervalSeconds = 5
|
||||
)
|
||||
|
||||
type authSessionStatus string
|
||||
|
||||
const (
|
||||
statusPending authSessionStatus = "pending"
|
||||
statusSuccess authSessionStatus = "success"
|
||||
statusFailed authSessionStatus = "failed"
|
||||
)
|
||||
|
||||
type webAuthSession struct {
|
||||
stateID string
|
||||
deviceCode string
|
||||
userCode string
|
||||
authURL string
|
||||
verificationURI string
|
||||
expiresIn int
|
||||
interval int
|
||||
status authSessionStatus
|
||||
startedAt time.Time
|
||||
completedAt time.Time
|
||||
expiresAt time.Time
|
||||
error string
|
||||
tokenData *KiroTokenData
|
||||
ssoClient *SSOOIDCClient
|
||||
clientID string
|
||||
clientSecret string
|
||||
region string
|
||||
cancelFunc context.CancelFunc
|
||||
authMethod string // "google", "github", "builder-id", "idc"
|
||||
startURL string // Used for IDC
|
||||
codeVerifier string // Used for social auth PKCE
|
||||
codeChallenge string // Used for social auth PKCE
|
||||
}
|
||||
|
||||
type OAuthWebHandler struct {
|
||||
cfg *config.Config
|
||||
sessions map[string]*webAuthSession
|
||||
mu sync.RWMutex
|
||||
onTokenObtained func(*KiroTokenData)
|
||||
}
|
||||
|
||||
func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler {
|
||||
return &OAuthWebHandler{
|
||||
cfg: cfg,
|
||||
sessions: make(map[string]*webAuthSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) SetTokenCallback(callback func(*KiroTokenData)) {
|
||||
h.onTokenObtained = callback
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) RegisterRoutes(router gin.IRouter) {
|
||||
oauth := router.Group("/v0/oauth/kiro")
|
||||
{
|
||||
oauth.GET("", h.handleSelect)
|
||||
oauth.GET("/start", h.handleStart)
|
||||
oauth.GET("/callback", h.handleCallback)
|
||||
oauth.GET("/social/callback", h.handleSocialCallback)
|
||||
oauth.GET("/status", h.handleStatus)
|
||||
oauth.POST("/import", h.handleImportToken)
|
||||
oauth.POST("/refresh", h.handleManualRefresh)
|
||||
}
|
||||
}
|
||||
|
||||
func generateStateID() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleSelect(c *gin.Context) {
|
||||
h.renderSelectPage(c)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleStart(c *gin.Context) {
|
||||
method := c.Query("method")
|
||||
|
||||
if method == "" {
|
||||
c.Redirect(http.StatusFound, "/v0/oauth/kiro")
|
||||
return
|
||||
}
|
||||
|
||||
switch method {
|
||||
case "google", "github":
|
||||
// Google/GitHub social login is not supported for third-party apps
|
||||
// due to AWS Cognito redirect_uri restrictions
|
||||
h.renderError(c, "Google/GitHub login is not available for third-party applications. Please use AWS Builder ID or import your token from Kiro IDE.")
|
||||
case "builder-id":
|
||||
h.startBuilderIDAuth(c)
|
||||
case "idc":
|
||||
h.startIDCAuth(c)
|
||||
default:
|
||||
h.renderError(c, fmt.Sprintf("Unknown authentication method: %s", method))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) startSocialAuth(c *gin.Context, method string) {
|
||||
stateID, err := generateStateID()
|
||||
if err != nil {
|
||||
h.renderError(c, "Failed to generate state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
codeVerifier, codeChallenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
h.renderError(c, "Failed to generate PKCE parameters")
|
||||
return
|
||||
}
|
||||
|
||||
socialClient := NewSocialAuthClient(h.cfg)
|
||||
|
||||
var provider string
|
||||
if method == "google" {
|
||||
provider = string(ProviderGoogle)
|
||||
} else {
|
||||
provider = string(ProviderGitHub)
|
||||
}
|
||||
|
||||
redirectURI := h.getSocialCallbackURL(c)
|
||||
authURL := socialClient.buildLoginURL(provider, redirectURI, codeChallenge, stateID)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
|
||||
session := &webAuthSession{
|
||||
stateID: stateID,
|
||||
authMethod: method,
|
||||
authURL: authURL,
|
||||
status: statusPending,
|
||||
startedAt: time.Now(),
|
||||
expiresIn: 600,
|
||||
codeVerifier: codeVerifier,
|
||||
codeChallenge: codeChallenge,
|
||||
region: "us-east-1",
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
h.sessions[stateID] = session
|
||||
h.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
h.mu.Lock()
|
||||
if session.status == statusPending {
|
||||
session.status = statusFailed
|
||||
session.error = "Authentication timed out"
|
||||
}
|
||||
h.mu.Unlock()
|
||||
}()
|
||||
|
||||
c.Redirect(http.StatusFound, authURL)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) getSocialCallbackURL(c *gin.Context) string {
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https" {
|
||||
scheme = "https"
|
||||
}
|
||||
return fmt.Sprintf("%s://%s/v0/oauth/kiro/social/callback", scheme, c.Request.Host)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) startBuilderIDAuth(c *gin.Context) {
|
||||
stateID, err := generateStateID()
|
||||
if err != nil {
|
||||
h.renderError(c, "Failed to generate state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
region := defaultIDCRegion
|
||||
startURL := builderIDStartURL
|
||||
|
||||
ssoClient := NewSSOOIDCClient(h.cfg)
|
||||
|
||||
regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to register client: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
|
||||
c.Request.Context(),
|
||||
regResp.ClientID,
|
||||
regResp.ClientSecret,
|
||||
startURL,
|
||||
region,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to start device authorization: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)
|
||||
|
||||
session := &webAuthSession{
|
||||
stateID: stateID,
|
||||
deviceCode: authResp.DeviceCode,
|
||||
userCode: authResp.UserCode,
|
||||
authURL: authResp.VerificationURIComplete,
|
||||
verificationURI: authResp.VerificationURI,
|
||||
expiresIn: authResp.ExpiresIn,
|
||||
interval: authResp.Interval,
|
||||
status: statusPending,
|
||||
startedAt: time.Now(),
|
||||
ssoClient: ssoClient,
|
||||
clientID: regResp.ClientID,
|
||||
clientSecret: regResp.ClientSecret,
|
||||
region: region,
|
||||
authMethod: "builder-id",
|
||||
startURL: startURL,
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
h.sessions[stateID] = session
|
||||
h.mu.Unlock()
|
||||
|
||||
go h.pollForToken(ctx, session)
|
||||
|
||||
h.renderStartPage(c, session)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) startIDCAuth(c *gin.Context) {
|
||||
startURL := c.Query("startUrl")
|
||||
region := c.Query("region")
|
||||
|
||||
if startURL == "" {
|
||||
h.renderError(c, "Missing startUrl parameter for IDC authentication")
|
||||
return
|
||||
}
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
|
||||
stateID, err := generateStateID()
|
||||
if err != nil {
|
||||
h.renderError(c, "Failed to generate state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
ssoClient := NewSSOOIDCClient(h.cfg)
|
||||
|
||||
regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to register client: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
|
||||
c.Request.Context(),
|
||||
regResp.ClientID,
|
||||
regResp.ClientSecret,
|
||||
startURL,
|
||||
region,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to start device authorization: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)
|
||||
|
||||
session := &webAuthSession{
|
||||
stateID: stateID,
|
||||
deviceCode: authResp.DeviceCode,
|
||||
userCode: authResp.UserCode,
|
||||
authURL: authResp.VerificationURIComplete,
|
||||
verificationURI: authResp.VerificationURI,
|
||||
expiresIn: authResp.ExpiresIn,
|
||||
interval: authResp.Interval,
|
||||
status: statusPending,
|
||||
startedAt: time.Now(),
|
||||
ssoClient: ssoClient,
|
||||
clientID: regResp.ClientID,
|
||||
clientSecret: regResp.ClientSecret,
|
||||
region: region,
|
||||
authMethod: "idc",
|
||||
startURL: startURL,
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
h.sessions[stateID] = session
|
||||
h.mu.Unlock()
|
||||
|
||||
go h.pollForToken(ctx, session)
|
||||
|
||||
h.renderStartPage(c, session)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSession) {
|
||||
defer session.cancelFunc()
|
||||
|
||||
interval := time.Duration(session.interval) * time.Second
|
||||
if interval < time.Duration(pollIntervalSeconds)*time.Second {
|
||||
interval = time.Duration(pollIntervalSeconds) * time.Second
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
h.mu.Lock()
|
||||
if session.status == statusPending {
|
||||
session.status = statusFailed
|
||||
session.error = "Authentication timed out"
|
||||
}
|
||||
h.mu.Unlock()
|
||||
return
|
||||
case <-ticker.C:
|
||||
tokenResp, err := h.ssoClient(session).CreateTokenWithRegion(
|
||||
ctx,
|
||||
session.clientID,
|
||||
session.clientSecret,
|
||||
session.deviceCode,
|
||||
session.region,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
if errStr == ErrAuthorizationPending.Error() {
|
||||
continue
|
||||
}
|
||||
if errStr == ErrSlowDown.Error() {
|
||||
interval += 5 * time.Second
|
||||
ticker.Reset(interval)
|
||||
continue
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
session.status = statusFailed
|
||||
session.error = errStr
|
||||
session.completedAt = time.Now()
|
||||
h.mu.Unlock()
|
||||
|
||||
log.Errorf("OAuth Web: token polling failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||
|
||||
// Fetch profileArn for IDC
|
||||
var profileArn string
|
||||
if session.authMethod == "idc" {
|
||||
profileArn = session.ssoClient.FetchProfileArn(ctx, tokenResp.AccessToken, session.clientID, tokenResp.RefreshToken)
|
||||
}
|
||||
|
||||
email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken, session.clientID, tokenResp.RefreshToken)
|
||||
|
||||
tokenData := &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: profileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: session.authMethod,
|
||||
Provider: "AWS",
|
||||
ClientID: session.clientID,
|
||||
ClientSecret: session.clientSecret,
|
||||
Email: email,
|
||||
Region: session.region,
|
||||
StartURL: session.startURL,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
session.status = statusSuccess
|
||||
session.completedAt = time.Now()
|
||||
session.expiresAt = expiresAt
|
||||
session.tokenData = tokenData
|
||||
h.mu.Unlock()
|
||||
|
||||
if h.onTokenObtained != nil {
|
||||
h.onTokenObtained(tokenData)
|
||||
}
|
||||
|
||||
// Save token to file
|
||||
h.saveTokenToFile(tokenData)
|
||||
|
||||
log.Infof("OAuth Web: authentication successful for %s", email)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// saveTokenToFile saves the token data to the auth directory
|
||||
func (h *OAuthWebHandler) saveTokenToFile(tokenData *KiroTokenData) {
|
||||
// Get auth directory from config or use default
|
||||
authDir := ""
|
||||
if h.cfg != nil && h.cfg.AuthDir != "" {
|
||||
var err error
|
||||
authDir, err = util.ResolveAuthDir(h.cfg.AuthDir)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to resolve auth directory: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to default location
|
||||
if authDir == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to get home directory: %v", err)
|
||||
return
|
||||
}
|
||||
authDir = filepath.Join(home, ".cli-proxy-api")
|
||||
}
|
||||
|
||||
// Create directory if not exists
|
||||
if err := os.MkdirAll(authDir, 0700); err != nil {
|
||||
log.Errorf("OAuth Web: failed to create auth directory: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate filename using the unified function
|
||||
fileName := GenerateTokenFileName(tokenData)
|
||||
|
||||
authFilePath := filepath.Join(authDir, fileName)
|
||||
|
||||
// Convert to storage format and save
|
||||
storage := &KiroTokenStorage{
|
||||
Type: "kiro",
|
||||
AccessToken: tokenData.AccessToken,
|
||||
RefreshToken: tokenData.RefreshToken,
|
||||
ProfileArn: tokenData.ProfileArn,
|
||||
ExpiresAt: tokenData.ExpiresAt,
|
||||
AuthMethod: tokenData.AuthMethod,
|
||||
Provider: tokenData.Provider,
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
ClientID: tokenData.ClientID,
|
||||
ClientSecret: tokenData.ClientSecret,
|
||||
Region: tokenData.Region,
|
||||
StartURL: tokenData.StartURL,
|
||||
Email: tokenData.Email,
|
||||
}
|
||||
|
||||
if err := storage.SaveTokenToFile(authFilePath); err != nil {
|
||||
log.Errorf("OAuth Web: failed to save token to file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("OAuth Web: token saved to %s", authFilePath)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) ssoClient(session *webAuthSession) *SSOOIDCClient {
|
||||
return session.ssoClient
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleCallback(c *gin.Context) {
|
||||
stateID := c.Query("state")
|
||||
errParam := c.Query("error")
|
||||
|
||||
if errParam != "" {
|
||||
h.renderError(c, errParam)
|
||||
return
|
||||
}
|
||||
|
||||
if stateID == "" {
|
||||
h.renderError(c, "Missing state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
session, exists := h.sessions[stateID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
h.renderError(c, "Invalid or expired session")
|
||||
return
|
||||
}
|
||||
|
||||
if session.status == statusSuccess {
|
||||
h.renderSuccess(c, session)
|
||||
} else if session.status == statusFailed {
|
||||
h.renderError(c, session.error)
|
||||
} else {
|
||||
c.Redirect(http.StatusFound, "/v0/oauth/kiro/start")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleSocialCallback(c *gin.Context) {
|
||||
stateID := c.Query("state")
|
||||
code := c.Query("code")
|
||||
errParam := c.Query("error")
|
||||
|
||||
if errParam != "" {
|
||||
h.renderError(c, errParam)
|
||||
return
|
||||
}
|
||||
|
||||
if stateID == "" {
|
||||
h.renderError(c, "Missing state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
if code == "" {
|
||||
h.renderError(c, "Missing authorization code")
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
session, exists := h.sessions[stateID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
h.renderError(c, "Invalid or expired session")
|
||||
return
|
||||
}
|
||||
|
||||
if session.authMethod != "google" && session.authMethod != "github" {
|
||||
h.renderError(c, "Invalid session type for social callback")
|
||||
return
|
||||
}
|
||||
|
||||
socialClient := NewSocialAuthClient(h.cfg)
|
||||
redirectURI := h.getSocialCallbackURL(c)
|
||||
|
||||
tokenReq := &CreateTokenRequest{
|
||||
Code: code,
|
||||
CodeVerifier: session.codeVerifier,
|
||||
RedirectURI: redirectURI,
|
||||
}
|
||||
|
||||
tokenResp, err := socialClient.CreateToken(c.Request.Context(), tokenReq)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: social token exchange failed: %v", err)
|
||||
h.mu.Lock()
|
||||
session.status = statusFailed
|
||||
session.error = fmt.Sprintf("Token exchange failed: %v", err)
|
||||
session.completedAt = time.Now()
|
||||
h.mu.Unlock()
|
||||
h.renderError(c, session.error)
|
||||
return
|
||||
}
|
||||
|
||||
expiresIn := tokenResp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
email := ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||
|
||||
var provider string
|
||||
if session.authMethod == "google" {
|
||||
provider = string(ProviderGoogle)
|
||||
} else {
|
||||
provider = string(ProviderGitHub)
|
||||
}
|
||||
|
||||
tokenData := &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: tokenResp.ProfileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: session.authMethod,
|
||||
Provider: provider,
|
||||
Email: email,
|
||||
Region: "us-east-1",
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
session.status = statusSuccess
|
||||
session.completedAt = time.Now()
|
||||
session.expiresAt = expiresAt
|
||||
session.tokenData = tokenData
|
||||
h.mu.Unlock()
|
||||
|
||||
if session.cancelFunc != nil {
|
||||
session.cancelFunc()
|
||||
}
|
||||
|
||||
if h.onTokenObtained != nil {
|
||||
h.onTokenObtained(tokenData)
|
||||
}
|
||||
|
||||
// Save token to file
|
||||
h.saveTokenToFile(tokenData)
|
||||
|
||||
log.Infof("OAuth Web: social authentication successful for %s via %s", email, provider)
|
||||
h.renderSuccess(c, session)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleStatus(c *gin.Context) {
|
||||
stateID := c.Query("state")
|
||||
if stateID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing state parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
session, exists := h.sessions[stateID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
|
||||
return
|
||||
}
|
||||
|
||||
response := gin.H{
|
||||
"status": string(session.status),
|
||||
}
|
||||
|
||||
switch session.status {
|
||||
case statusPending:
|
||||
elapsed := time.Since(session.startedAt).Seconds()
|
||||
remaining := float64(session.expiresIn) - elapsed
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
response["remaining_seconds"] = int(remaining)
|
||||
case statusSuccess:
|
||||
response["completed_at"] = session.completedAt.Format(time.RFC3339)
|
||||
response["expires_at"] = session.expiresAt.Format(time.RFC3339)
|
||||
case statusFailed:
|
||||
response["error"] = session.error
|
||||
response["failed_at"] = session.completedAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderStartPage(c *gin.Context, session *webAuthSession) {
|
||||
tmpl, err := template.New("start").Parse(oauthWebStartPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"AuthURL": session.authURL,
|
||||
"UserCode": session.userCode,
|
||||
"ExpiresIn": session.expiresIn,
|
||||
"StateID": session.stateID,
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
if err := tmpl.Execute(c.Writer, data); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderSelectPage(c *gin.Context) {
|
||||
tmpl, err := template.New("select").Parse(oauthWebSelectPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse select template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
if err := tmpl.Execute(c.Writer, nil); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render select template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderError(c *gin.Context, errMsg string) {
|
||||
tmpl, err := template.New("error").Parse(oauthWebErrorPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse error template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"Error": errMsg,
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.Status(http.StatusBadRequest)
|
||||
if err := tmpl.Execute(c.Writer, data); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render error template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderSuccess(c *gin.Context, session *webAuthSession) {
|
||||
tmpl, err := template.New("success").Parse(oauthWebSuccessPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse success template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"ExpiresAt": session.expiresAt.Format(time.RFC3339),
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
if err := tmpl.Execute(c.Writer, data); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render success template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) CleanupExpiredSessions() {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for id, session := range h.sessions {
|
||||
if session.status != statusPending && now.Sub(session.completedAt) > 30*time.Minute {
|
||||
delete(h.sessions, id)
|
||||
} else if session.status == statusPending && now.Sub(session.startedAt) > defaultSessionExpiry {
|
||||
session.cancelFunc()
|
||||
delete(h.sessions, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) GetSession(stateID string) (*webAuthSession, bool) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
session, exists := h.sessions[stateID]
|
||||
return session, exists
|
||||
}
|
||||
|
||||
// ImportTokenRequest represents the request body for token import
|
||||
type ImportTokenRequest struct {
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
// handleImportToken handles manual refresh token import from Kiro IDE
|
||||
func (h *OAuthWebHandler) handleImportToken(c *gin.Context) {
|
||||
var req ImportTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": "Invalid request body",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
refreshToken := strings.TrimSpace(req.RefreshToken)
|
||||
if refreshToken == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": "Refresh token is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate token format
|
||||
if !strings.HasPrefix(refreshToken, "aorAAAAAG") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": "Invalid token format. Token should start with aorAAAAAG...",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Create social auth client to refresh and validate the token
|
||||
socialClient := NewSocialAuthClient(h.cfg)
|
||||
|
||||
// Refresh the token to validate it and get access token
|
||||
tokenData, err := socialClient.RefreshSocialToken(c.Request.Context(), refreshToken)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: token refresh failed during import: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": fmt.Sprintf("Token validation failed: %v", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Set the original refresh token (the refreshed one might be empty)
|
||||
if tokenData.RefreshToken == "" {
|
||||
tokenData.RefreshToken = refreshToken
|
||||
}
|
||||
tokenData.AuthMethod = "social"
|
||||
tokenData.Provider = "imported"
|
||||
|
||||
// Notify callback if set
|
||||
if h.onTokenObtained != nil {
|
||||
h.onTokenObtained(tokenData)
|
||||
}
|
||||
|
||||
// Save token to file
|
||||
h.saveTokenToFile(tokenData)
|
||||
|
||||
// Generate filename for response using the unified function
|
||||
fileName := GenerateTokenFileName(tokenData)
|
||||
|
||||
log.Infof("OAuth Web: token imported successfully")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Token imported successfully",
|
||||
"fileName": fileName,
|
||||
})
|
||||
}
|
||||
|
||||
// handleManualRefresh handles manual token refresh requests from the web UI.
|
||||
// This allows users to trigger a token refresh when needed, without waiting
|
||||
// for the automatic 30-second check and 20-minute-before-expiry refresh cycle.
|
||||
// Uses the same refresh logic as kiro_executor.Refresh for consistency.
|
||||
func (h *OAuthWebHandler) handleManualRefresh(c *gin.Context) {
|
||||
authDir := ""
|
||||
if h.cfg != nil && h.cfg.AuthDir != "" {
|
||||
var err error
|
||||
authDir, err = util.ResolveAuthDir(h.cfg.AuthDir)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to resolve auth directory: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if authDir == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"error": "Failed to get home directory",
|
||||
})
|
||||
return
|
||||
}
|
||||
authDir = filepath.Join(home, ".cli-proxy-api")
|
||||
}
|
||||
|
||||
// Find all kiro token files in the auth directory
|
||||
files, err := os.ReadDir(authDir)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"error": fmt.Sprintf("Failed to read auth directory: %v", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var refreshedCount int
|
||||
var errors []string
|
||||
|
||||
for _, file := range files {
|
||||
if file.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := file.Name()
|
||||
if !strings.HasPrefix(name, "kiro-") || !strings.HasSuffix(name, ".json") {
|
||||
continue
|
||||
}
|
||||
|
||||
filePath := filepath.Join(authDir, name)
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
errors = append(errors, fmt.Sprintf("%s: read error - %v", name, err))
|
||||
continue
|
||||
}
|
||||
|
||||
var storage KiroTokenStorage
|
||||
if err := json.Unmarshal(data, &storage); err != nil {
|
||||
errors = append(errors, fmt.Sprintf("%s: parse error - %v", name, err))
|
||||
continue
|
||||
}
|
||||
|
||||
if storage.RefreshToken == "" {
|
||||
errors = append(errors, fmt.Sprintf("%s: no refresh token", name))
|
||||
continue
|
||||
}
|
||||
|
||||
// Refresh token using the same logic as kiro_executor.Refresh
|
||||
tokenData, err := h.refreshTokenData(c.Request.Context(), &storage)
|
||||
if err != nil {
|
||||
errors = append(errors, fmt.Sprintf("%s: refresh failed - %v", name, err))
|
||||
continue
|
||||
}
|
||||
|
||||
// Update storage with new token data
|
||||
storage.AccessToken = tokenData.AccessToken
|
||||
if tokenData.RefreshToken != "" {
|
||||
storage.RefreshToken = tokenData.RefreshToken
|
||||
}
|
||||
storage.ExpiresAt = tokenData.ExpiresAt
|
||||
storage.LastRefresh = time.Now().Format(time.RFC3339)
|
||||
if tokenData.ProfileArn != "" {
|
||||
storage.ProfileArn = tokenData.ProfileArn
|
||||
}
|
||||
|
||||
// Write updated token back to file
|
||||
updatedData, err := json.MarshalIndent(storage, "", " ")
|
||||
if err != nil {
|
||||
errors = append(errors, fmt.Sprintf("%s: marshal error - %v", name, err))
|
||||
continue
|
||||
}
|
||||
|
||||
tmpFile := filePath + ".tmp"
|
||||
if err := os.WriteFile(tmpFile, updatedData, 0600); err != nil {
|
||||
errors = append(errors, fmt.Sprintf("%s: write error - %v", name, err))
|
||||
continue
|
||||
}
|
||||
if err := os.Rename(tmpFile, filePath); err != nil {
|
||||
errors = append(errors, fmt.Sprintf("%s: rename error - %v", name, err))
|
||||
continue
|
||||
}
|
||||
|
||||
log.Infof("OAuth Web: manually refreshed token in %s, expires at %s", name, tokenData.ExpiresAt)
|
||||
refreshedCount++
|
||||
|
||||
// Notify callback if set
|
||||
if h.onTokenObtained != nil {
|
||||
h.onTokenObtained(tokenData)
|
||||
}
|
||||
}
|
||||
|
||||
if refreshedCount == 0 && len(errors) > 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": fmt.Sprintf("All refresh attempts failed: %v", errors),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response := gin.H{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf("Refreshed %d token(s)", refreshedCount),
|
||||
"refreshedCount": refreshedCount,
|
||||
}
|
||||
if len(errors) > 0 {
|
||||
response["warnings"] = errors
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// refreshTokenData refreshes a token using the appropriate method based on auth type.
|
||||
// This mirrors the logic in kiro_executor.Refresh for consistency.
|
||||
func (h *OAuthWebHandler) refreshTokenData(ctx context.Context, storage *KiroTokenStorage) (*KiroTokenData, error) {
|
||||
ssoClient := NewSSOOIDCClient(h.cfg)
|
||||
|
||||
switch {
|
||||
case storage.ClientID != "" && storage.ClientSecret != "" && storage.AuthMethod == "idc" && storage.Region != "":
|
||||
// IDC refresh with region-specific endpoint
|
||||
log.Debugf("OAuth Web: using SSO OIDC refresh for IDC (region=%s)", storage.Region)
|
||||
return ssoClient.RefreshTokenWithRegion(ctx, storage.ClientID, storage.ClientSecret, storage.RefreshToken, storage.Region, storage.StartURL)
|
||||
|
||||
case storage.ClientID != "" && storage.ClientSecret != "" && storage.AuthMethod == "builder-id":
|
||||
// Builder ID refresh with default endpoint
|
||||
log.Debugf("OAuth Web: using SSO OIDC refresh for AWS Builder ID")
|
||||
return ssoClient.RefreshToken(ctx, storage.ClientID, storage.ClientSecret, storage.RefreshToken)
|
||||
|
||||
default:
|
||||
// Fallback to Kiro's OAuth refresh endpoint (for social auth: Google/GitHub)
|
||||
log.Debugf("OAuth Web: using Kiro OAuth refresh endpoint")
|
||||
oauth := NewKiroOAuth(h.cfg)
|
||||
return oauth.RefreshToken(ctx, storage.RefreshToken)
|
||||
}
|
||||
}
|
||||
779
internal/auth/kiro/oauth_web_templates.go
Normal file
779
internal/auth/kiro/oauth_web_templates.go
Normal file
@@ -0,0 +1,779 @@
|
||||
// Package kiro provides OAuth Web authentication templates.
|
||||
package kiro
|
||||
|
||||
const (
|
||||
oauthWebStartPageHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>AWS SSO Authentication</title>
|
||||
<style>
|
||||
* { box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
.container {
|
||||
max-width: 500px;
|
||||
width: 100%;
|
||||
background: #fff;
|
||||
padding: 40px;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
|
||||
}
|
||||
h1 {
|
||||
margin: 0 0 10px;
|
||||
color: #333;
|
||||
font-size: 24px;
|
||||
text-align: center;
|
||||
}
|
||||
.subtitle {
|
||||
text-align: center;
|
||||
color: #666;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
.step {
|
||||
background: #f8f9fa;
|
||||
padding: 20px;
|
||||
border-radius: 8px;
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
.step-title {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
.step-number {
|
||||
width: 28px;
|
||||
height: 28px;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
border-radius: 50%;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 14px;
|
||||
margin-right: 12px;
|
||||
}
|
||||
.user-code {
|
||||
background: #e7f3ff;
|
||||
border: 2px dashed #2196F3;
|
||||
border-radius: 8px;
|
||||
padding: 20px;
|
||||
text-align: center;
|
||||
margin-top: 10px;
|
||||
}
|
||||
.user-code-label {
|
||||
font-size: 12px;
|
||||
color: #666;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 1px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.user-code-value {
|
||||
font-size: 32px;
|
||||
font-weight: bold;
|
||||
font-family: monospace;
|
||||
color: #2196F3;
|
||||
letter-spacing: 4px;
|
||||
}
|
||||
.auth-btn {
|
||||
display: block;
|
||||
width: 100%;
|
||||
padding: 15px;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
text-align: center;
|
||||
text-decoration: none;
|
||||
border-radius: 8px;
|
||||
font-weight: 600;
|
||||
font-size: 16px;
|
||||
transition: all 0.3s;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
margin-top: 20px;
|
||||
}
|
||||
.auth-btn:hover {
|
||||
background: #5568d3;
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
|
||||
}
|
||||
.status {
|
||||
margin-top: 30px;
|
||||
padding: 20px;
|
||||
background: #f8f9fa;
|
||||
border-radius: 8px;
|
||||
text-align: center;
|
||||
}
|
||||
.status-pending { border-left: 4px solid #ffc107; }
|
||||
.status-success { border-left: 4px solid #28a745; }
|
||||
.status-failed { border-left: 4px solid #dc3545; }
|
||||
.spinner {
|
||||
border: 3px solid #f3f3f3;
|
||||
border-top: 3px solid #667eea;
|
||||
border-radius: 50%;
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
animation: spin 1s linear infinite;
|
||||
margin: 0 auto 15px;
|
||||
}
|
||||
@keyframes spin {
|
||||
0% { transform: rotate(0deg); }
|
||||
100% { transform: rotate(360deg); }
|
||||
}
|
||||
.timer {
|
||||
font-size: 24px;
|
||||
font-weight: bold;
|
||||
color: #667eea;
|
||||
margin: 10px 0;
|
||||
}
|
||||
.timer.warning { color: #ffc107; }
|
||||
.timer.danger { color: #dc3545; }
|
||||
.status-message { color: #666; line-height: 1.6; }
|
||||
.success-icon, .error-icon { font-size: 48px; margin-bottom: 15px; }
|
||||
.info-box {
|
||||
background: #e7f3ff;
|
||||
border-left: 4px solid #2196F3;
|
||||
padding: 15px;
|
||||
margin-top: 20px;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
color: #666;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>🔐 AWS SSO Authentication</h1>
|
||||
<p class="subtitle">Follow the steps below to complete authentication</p>
|
||||
|
||||
<div class="step">
|
||||
<div class="step-title">
|
||||
<span class="step-number">1</span>
|
||||
Click the button below to open the authorization page
|
||||
</div>
|
||||
<a href="{{.AuthURL}}" target="_blank" class="auth-btn" id="authBtn">
|
||||
🚀 Open Authorization Page
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div class="step">
|
||||
<div class="step-title">
|
||||
<span class="step-number">2</span>
|
||||
Enter the verification code below
|
||||
</div>
|
||||
<div class="user-code">
|
||||
<div class="user-code-label">Verification Code</div>
|
||||
<div class="user-code-value">{{.UserCode}}</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="step">
|
||||
<div class="step-title">
|
||||
<span class="step-number">3</span>
|
||||
Complete AWS SSO login
|
||||
</div>
|
||||
<p style="color: #666; font-size: 14px; margin-top: 10px;">
|
||||
Use your AWS SSO account to login and authorize
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div class="status status-pending" id="statusBox">
|
||||
<div class="spinner" id="spinner"></div>
|
||||
<div class="timer" id="timer">{{.ExpiresIn}}s</div>
|
||||
<div class="status-message" id="statusMessage">
|
||||
Waiting for authorization...
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="info-box">
|
||||
💡 <strong>Tip:</strong> The authorization page will open in a new tab. This page will automatically update once authorization is complete.
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let pollInterval;
|
||||
let timerInterval;
|
||||
let remainingSeconds = {{.ExpiresIn}};
|
||||
const stateID = "{{.StateID}}";
|
||||
|
||||
setTimeout(() => {
|
||||
document.getElementById('authBtn').click();
|
||||
}, 500);
|
||||
|
||||
function pollStatus() {
|
||||
fetch('/v0/oauth/kiro/status?state=' + stateID)
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
console.log('Status:', data);
|
||||
if (data.status === 'success') {
|
||||
clearInterval(pollInterval);
|
||||
clearInterval(timerInterval);
|
||||
showSuccess(data);
|
||||
} else if (data.status === 'failed') {
|
||||
clearInterval(pollInterval);
|
||||
clearInterval(timerInterval);
|
||||
showError(data);
|
||||
} else {
|
||||
remainingSeconds = data.remaining_seconds || 0;
|
||||
}
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Poll error:', error);
|
||||
});
|
||||
}
|
||||
|
||||
function updateTimer() {
|
||||
const timerEl = document.getElementById('timer');
|
||||
const minutes = Math.floor(remainingSeconds / 60);
|
||||
const seconds = remainingSeconds % 60;
|
||||
timerEl.textContent = minutes + ':' + seconds.toString().padStart(2, '0');
|
||||
|
||||
if (remainingSeconds < 60) {
|
||||
timerEl.className = 'timer danger';
|
||||
} else if (remainingSeconds < 180) {
|
||||
timerEl.className = 'timer warning';
|
||||
} else {
|
||||
timerEl.className = 'timer';
|
||||
}
|
||||
|
||||
remainingSeconds--;
|
||||
|
||||
if (remainingSeconds < 0) {
|
||||
clearInterval(timerInterval);
|
||||
clearInterval(pollInterval);
|
||||
showError({ error: 'Authentication timed out. Please refresh and try again.' });
|
||||
}
|
||||
}
|
||||
|
||||
function showSuccess(data) {
|
||||
const statusBox = document.getElementById('statusBox');
|
||||
statusBox.className = 'status status-success';
|
||||
statusBox.innerHTML = '<div class="success-icon">✅</div>' +
|
||||
'<div class="status-message">' +
|
||||
'<strong>Authentication Successful!</strong><br>' +
|
||||
'Token expires: ' + new Date(data.expires_at).toLocaleString() +
|
||||
'</div>';
|
||||
}
|
||||
|
||||
function showError(data) {
|
||||
const statusBox = document.getElementById('statusBox');
|
||||
statusBox.className = 'status status-failed';
|
||||
statusBox.innerHTML = '<div class="error-icon">❌</div>' +
|
||||
'<div class="status-message">' +
|
||||
'<strong>Authentication Failed</strong><br>' +
|
||||
(data.error || 'Unknown error') +
|
||||
'</div>' +
|
||||
'<button class="auth-btn" onclick="location.reload()" style="margin-top: 15px;">' +
|
||||
'🔄 Retry' +
|
||||
'</button>';
|
||||
}
|
||||
|
||||
pollInterval = setInterval(pollStatus, 3000);
|
||||
timerInterval = setInterval(updateTimer, 1000);
|
||||
pollStatus();
|
||||
</script>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
oauthWebErrorPageHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Authentication Failed</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
max-width: 600px;
|
||||
margin: 50px auto;
|
||||
padding: 20px;
|
||||
background: #f5f5f5;
|
||||
}
|
||||
.error {
|
||||
background: #fff;
|
||||
padding: 30px;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||
border-left: 4px solid #dc3545;
|
||||
}
|
||||
h1 { color: #dc3545; margin-top: 0; }
|
||||
.error-message { color: #666; line-height: 1.6; }
|
||||
.retry-btn {
|
||||
display: inline-block;
|
||||
margin-top: 20px;
|
||||
padding: 10px 20px;
|
||||
background: #007bff;
|
||||
color: white;
|
||||
text-decoration: none;
|
||||
border-radius: 4px;
|
||||
}
|
||||
.retry-btn:hover { background: #0056b3; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="error">
|
||||
<h1>❌ Authentication Failed</h1>
|
||||
<div class="error-message">
|
||||
<p><strong>Error:</strong></p>
|
||||
<p>{{.Error}}</p>
|
||||
</div>
|
||||
<a href="/v0/oauth/kiro/start" class="retry-btn">🔄 Retry</a>
|
||||
</div>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
oauthWebSuccessPageHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Authentication Successful</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
max-width: 600px;
|
||||
margin: 50px auto;
|
||||
padding: 20px;
|
||||
background: #f5f5f5;
|
||||
}
|
||||
.success {
|
||||
background: #fff;
|
||||
padding: 30px;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||
border-left: 4px solid #28a745;
|
||||
text-align: center;
|
||||
}
|
||||
h1 { color: #28a745; margin-top: 0; }
|
||||
.success-message { color: #666; line-height: 1.6; }
|
||||
.icon { font-size: 48px; margin-bottom: 15px; }
|
||||
.expires { font-size: 14px; color: #999; margin-top: 15px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="success">
|
||||
<div class="icon">✅</div>
|
||||
<h1>Authentication Successful!</h1>
|
||||
<div class="success-message">
|
||||
<p>You can close this window.</p>
|
||||
</div>
|
||||
<div class="expires">Token expires: {{.ExpiresAt}}</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
oauthWebSelectPageHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Select Authentication Method</title>
|
||||
<style>
|
||||
* { box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
.container {
|
||||
max-width: 500px;
|
||||
width: 100%;
|
||||
background: #fff;
|
||||
padding: 40px;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
|
||||
}
|
||||
h1 {
|
||||
margin: 0 0 10px;
|
||||
color: #333;
|
||||
font-size: 24px;
|
||||
text-align: center;
|
||||
}
|
||||
.subtitle {
|
||||
text-align: center;
|
||||
color: #666;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
.auth-methods {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 15px;
|
||||
}
|
||||
.auth-btn {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
width: 100%;
|
||||
padding: 15px 20px;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
text-decoration: none;
|
||||
border-radius: 8px;
|
||||
font-weight: 600;
|
||||
font-size: 16px;
|
||||
transition: all 0.3s;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
}
|
||||
.auth-btn:hover {
|
||||
background: #5568d3;
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
|
||||
}
|
||||
.auth-btn .icon {
|
||||
font-size: 24px;
|
||||
margin-right: 15px;
|
||||
width: 32px;
|
||||
text-align: center;
|
||||
}
|
||||
.auth-btn.google { background: #4285F4; }
|
||||
.auth-btn.google:hover { background: #3367D6; }
|
||||
.auth-btn.github { background: #24292e; }
|
||||
.auth-btn.github:hover { background: #1a1e22; }
|
||||
.auth-btn.aws { background: #FF9900; }
|
||||
.auth-btn.aws:hover { background: #E68A00; }
|
||||
.auth-btn.idc { background: #232F3E; }
|
||||
.auth-btn.idc:hover { background: #1a242f; }
|
||||
.idc-form {
|
||||
background: #f8f9fa;
|
||||
padding: 20px;
|
||||
border-radius: 8px;
|
||||
margin-top: 15px;
|
||||
display: none;
|
||||
}
|
||||
.idc-form.show {
|
||||
display: block;
|
||||
}
|
||||
.form-group {
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
.form-group label {
|
||||
display: block;
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
margin-bottom: 8px;
|
||||
font-size: 14px;
|
||||
}
|
||||
.form-group input {
|
||||
width: 100%;
|
||||
padding: 12px;
|
||||
border: 2px solid #e0e0e0;
|
||||
border-radius: 6px;
|
||||
font-size: 14px;
|
||||
transition: border-color 0.3s;
|
||||
}
|
||||
.form-group input:focus {
|
||||
outline: none;
|
||||
border-color: #667eea;
|
||||
}
|
||||
.form-group .hint {
|
||||
font-size: 12px;
|
||||
color: #999;
|
||||
margin-top: 5px;
|
||||
}
|
||||
.submit-btn {
|
||||
display: block;
|
||||
width: 100%;
|
||||
padding: 15px;
|
||||
background: #232F3E;
|
||||
color: white;
|
||||
text-align: center;
|
||||
text-decoration: none;
|
||||
border-radius: 8px;
|
||||
font-weight: 600;
|
||||
font-size: 16px;
|
||||
transition: all 0.3s;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
}
|
||||
.submit-btn:hover {
|
||||
background: #1a242f;
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 12px rgba(35, 47, 62, 0.4);
|
||||
}
|
||||
.divider {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin: 20px 0;
|
||||
}
|
||||
.divider::before,
|
||||
.divider::after {
|
||||
content: "";
|
||||
flex: 1;
|
||||
border-bottom: 1px solid #e0e0e0;
|
||||
}
|
||||
.divider span {
|
||||
padding: 0 15px;
|
||||
color: #999;
|
||||
font-size: 14px;
|
||||
}
|
||||
.info-box {
|
||||
background: #e7f3ff;
|
||||
border-left: 4px solid #2196F3;
|
||||
padding: 15px;
|
||||
margin-top: 20px;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
color: #666;
|
||||
}
|
||||
.warning-box {
|
||||
background: #fff3cd;
|
||||
border-left: 4px solid #ffc107;
|
||||
padding: 15px;
|
||||
margin-top: 20px;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
color: #856404;
|
||||
}
|
||||
.auth-btn.manual { background: #6c757d; }
|
||||
.auth-btn.manual:hover { background: #5a6268; }
|
||||
.auth-btn.refresh { background: #17a2b8; }
|
||||
.auth-btn.refresh:hover { background: #138496; }
|
||||
.auth-btn.refresh:disabled { background: #7fb3bd; cursor: not-allowed; }
|
||||
.manual-form {
|
||||
background: #f8f9fa;
|
||||
padding: 20px;
|
||||
border-radius: 8px;
|
||||
margin-top: 15px;
|
||||
display: none;
|
||||
}
|
||||
.manual-form.show {
|
||||
display: block;
|
||||
}
|
||||
.form-group textarea {
|
||||
width: 100%;
|
||||
padding: 12px;
|
||||
border: 2px solid #e0e0e0;
|
||||
border-radius: 6px;
|
||||
font-size: 14px;
|
||||
font-family: monospace;
|
||||
transition: border-color 0.3s;
|
||||
resize: vertical;
|
||||
min-height: 80px;
|
||||
}
|
||||
.form-group textarea:focus {
|
||||
outline: none;
|
||||
border-color: #667eea;
|
||||
}
|
||||
.status-message {
|
||||
padding: 15px;
|
||||
border-radius: 6px;
|
||||
margin-top: 15px;
|
||||
display: none;
|
||||
}
|
||||
.status-message.success {
|
||||
background: #d4edda;
|
||||
color: #155724;
|
||||
display: block;
|
||||
}
|
||||
.status-message.error {
|
||||
background: #f8d7da;
|
||||
color: #721c24;
|
||||
display: block;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>🔐 Select Authentication Method</h1>
|
||||
<p class="subtitle">Choose how you want to authenticate with Kiro</p>
|
||||
|
||||
<div class="auth-methods">
|
||||
<a href="/v0/oauth/kiro/start?method=builder-id" class="auth-btn aws">
|
||||
<span class="icon">🔶</span>
|
||||
AWS Builder ID (Recommended)
|
||||
</a>
|
||||
|
||||
<button type="button" class="auth-btn idc" onclick="toggleIdcForm()">
|
||||
<span class="icon">🏢</span>
|
||||
AWS Identity Center (IDC)
|
||||
</button>
|
||||
|
||||
<div class="divider"><span>or</span></div>
|
||||
|
||||
<button type="button" class="auth-btn manual" onclick="toggleManualForm()">
|
||||
<span class="icon">📋</span>
|
||||
Import RefreshToken from Kiro IDE
|
||||
</button>
|
||||
|
||||
<button type="button" class="auth-btn refresh" onclick="manualRefresh()" id="refreshBtn">
|
||||
<span class="icon">🔄</span>
|
||||
Manual Refresh All Tokens
|
||||
</button>
|
||||
|
||||
<div class="status-message" id="refreshStatus"></div>
|
||||
</div>
|
||||
|
||||
<div class="idc-form" id="idcForm">
|
||||
<form action="/v0/oauth/kiro/start" method="get">
|
||||
<input type="hidden" name="method" value="idc">
|
||||
|
||||
<div class="form-group">
|
||||
<label for="startUrl">Start URL</label>
|
||||
<input type="url" id="startUrl" name="startUrl" placeholder="https://your-org.awsapps.com/start" required>
|
||||
<div class="hint">Your AWS Identity Center Start URL</div>
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label for="region">Region</label>
|
||||
<input type="text" id="region" name="region" value="us-east-1" placeholder="us-east-1">
|
||||
<div class="hint">AWS Region for your Identity Center</div>
|
||||
</div>
|
||||
|
||||
<button type="submit" class="submit-btn">
|
||||
🚀 Continue with IDC
|
||||
</button>
|
||||
</form>
|
||||
</div>
|
||||
|
||||
<div class="manual-form" id="manualForm">
|
||||
<form id="importForm" onsubmit="submitImport(event)">
|
||||
<div class="form-group">
|
||||
<label for="refreshToken">Refresh Token</label>
|
||||
<textarea id="refreshToken" name="refreshToken" placeholder="Paste your refreshToken here (starts with aorAAAAAG...)" required></textarea>
|
||||
<div class="hint">Copy from Kiro IDE: ~/.kiro/kiro-auth-token.json → refreshToken field</div>
|
||||
</div>
|
||||
|
||||
<button type="submit" class="submit-btn" id="importBtn">
|
||||
📥 Import Token
|
||||
</button>
|
||||
|
||||
<div class="status-message" id="importStatus"></div>
|
||||
</form>
|
||||
</div>
|
||||
|
||||
<div class="warning-box">
|
||||
⚠️ <strong>Note:</strong> Google and GitHub login are not available for third-party applications due to AWS Cognito restrictions. Please use AWS Builder ID or import your token from Kiro IDE.
|
||||
</div>
|
||||
|
||||
<div class="info-box">
|
||||
💡 <strong>How to get RefreshToken:</strong><br>
|
||||
1. Open Kiro IDE and login with Google/GitHub<br>
|
||||
2. Find the token file: <code>~/.kiro/kiro-auth-token.json</code><br>
|
||||
3. Copy the <code>refreshToken</code> value and paste it above
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
function toggleIdcForm() {
|
||||
const idcForm = document.getElementById('idcForm');
|
||||
const manualForm = document.getElementById('manualForm');
|
||||
manualForm.classList.remove('show');
|
||||
idcForm.classList.toggle('show');
|
||||
if (idcForm.classList.contains('show')) {
|
||||
document.getElementById('startUrl').focus();
|
||||
}
|
||||
}
|
||||
|
||||
function toggleManualForm() {
|
||||
const idcForm = document.getElementById('idcForm');
|
||||
const manualForm = document.getElementById('manualForm');
|
||||
idcForm.classList.remove('show');
|
||||
manualForm.classList.toggle('show');
|
||||
if (manualForm.classList.contains('show')) {
|
||||
document.getElementById('refreshToken').focus();
|
||||
}
|
||||
}
|
||||
|
||||
async function submitImport(event) {
|
||||
event.preventDefault();
|
||||
const refreshToken = document.getElementById('refreshToken').value.trim();
|
||||
const statusEl = document.getElementById('importStatus');
|
||||
const btn = document.getElementById('importBtn');
|
||||
|
||||
if (!refreshToken) {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = 'Please enter a refresh token';
|
||||
return;
|
||||
}
|
||||
|
||||
if (!refreshToken.startsWith('aorAAAAAG')) {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = 'Invalid token format. Token should start with aorAAAAAG...';
|
||||
return;
|
||||
}
|
||||
|
||||
btn.disabled = true;
|
||||
btn.textContent = '⏳ Importing...';
|
||||
statusEl.className = 'status-message';
|
||||
statusEl.style.display = 'none';
|
||||
|
||||
try {
|
||||
const response = await fetch('/v0/oauth/kiro/import', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ refreshToken: refreshToken })
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (response.ok && data.success) {
|
||||
statusEl.className = 'status-message success';
|
||||
statusEl.textContent = '✅ Token imported successfully! File: ' + (data.fileName || 'kiro-token.json');
|
||||
} else {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = '❌ ' + (data.error || data.message || 'Import failed');
|
||||
}
|
||||
} catch (error) {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = '❌ Network error: ' + error.message;
|
||||
} finally {
|
||||
btn.disabled = false;
|
||||
btn.textContent = '📥 Import Token';
|
||||
}
|
||||
}
|
||||
|
||||
async function manualRefresh() {
|
||||
const btn = document.getElementById('refreshBtn');
|
||||
const statusEl = document.getElementById('refreshStatus');
|
||||
|
||||
btn.disabled = true;
|
||||
btn.innerHTML = '<span class="icon">⏳</span> Refreshing...';
|
||||
statusEl.className = 'status-message';
|
||||
statusEl.style.display = 'none';
|
||||
|
||||
try {
|
||||
const response = await fetch('/v0/oauth/kiro/refresh', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' }
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (response.ok && data.success) {
|
||||
statusEl.className = 'status-message success';
|
||||
let msg = '✅ ' + data.message;
|
||||
if (data.warnings && data.warnings.length > 0) {
|
||||
msg += ' (Warnings: ' + data.warnings.join('; ') + ')';
|
||||
}
|
||||
statusEl.textContent = msg;
|
||||
} else {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = '❌ ' + (data.error || data.message || 'Refresh failed');
|
||||
}
|
||||
} catch (error) {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = '❌ Network error: ' + error.message;
|
||||
} finally {
|
||||
btn.disabled = false;
|
||||
btn.innerHTML = '<span class="icon">🔄</span> Manual Refresh All Tokens';
|
||||
}
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>`
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user