mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-01-26 08:51:05 +00:00
feat: implement session_jwt and enhance auth
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
import os
|
||||
import platform
|
||||
import uuid
|
||||
|
||||
import dotenv
|
||||
from flask import Flask, jsonify, redirect, request
|
||||
from jose import jwt
|
||||
|
||||
from application.auth import get_or_create_user_id, handle_auth
|
||||
from application.auth import handle_auth
|
||||
|
||||
from application.core.logging_config import setup_logging
|
||||
|
||||
@@ -38,10 +40,22 @@ app.config.update(
|
||||
celery.config_from_object("application.celeryconfig")
|
||||
api.init_app(app)
|
||||
|
||||
if settings.AUTH_TYPE in ("simple_jwt", "session_jwt") and not settings.JWT_SECRET_KEY:
|
||||
key_file = ".jwt_secret_key"
|
||||
try:
|
||||
with open(key_file, "r") as f:
|
||||
settings.JWT_SECRET_KEY = f.read().strip()
|
||||
except FileNotFoundError:
|
||||
new_key = os.urandom(32).hex()
|
||||
with open(key_file, "w") as f:
|
||||
f.write(new_key)
|
||||
settings.JWT_SECRET_KEY = new_key
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to setup JWT_SECRET_KEY: {e}")
|
||||
|
||||
SIMPLE_JWT_TOKEN = None
|
||||
if settings.AUTH_TYPE == "simple_jwt":
|
||||
user_id = get_or_create_user_id()
|
||||
payload = {"sub": user_id}
|
||||
payload = {"sub": "local"}
|
||||
SIMPLE_JWT_TOKEN = jwt.encode(payload, settings.JWT_SECRET_KEY, algorithm="HS256")
|
||||
print(f"Generated Simple JWT Token: {SIMPLE_JWT_TOKEN}")
|
||||
|
||||
@@ -54,13 +68,33 @@ def home():
|
||||
return "Welcome to DocsGPT Backend!"
|
||||
|
||||
|
||||
@app.route("/api/config")
|
||||
def get_config():
|
||||
response = {
|
||||
"auth_type": settings.AUTH_TYPE,
|
||||
"requires_auth": settings.AUTH_TYPE in ["simple_jwt", "session_jwt"],
|
||||
}
|
||||
return jsonify(response)
|
||||
|
||||
|
||||
@app.route("/api/generate_token")
|
||||
def generate_token():
|
||||
if settings.AUTH_TYPE == "session_jwt":
|
||||
new_user_id = str(uuid.uuid4())
|
||||
token = jwt.encode(
|
||||
{"sub": new_user_id}, settings.JWT_SECRET_KEY, algorithm="HS256"
|
||||
)
|
||||
return jsonify({"token": token})
|
||||
return jsonify({"error": "Token generation not allowed in current auth mode"}), 400
|
||||
|
||||
|
||||
@app.before_request
|
||||
def authenticate_request():
|
||||
if request.method == "OPTIONS":
|
||||
return "", 200
|
||||
|
||||
decoded_token = handle_auth(request)
|
||||
if "message" in decoded_token:
|
||||
if not decoded_token:
|
||||
request.decoded_token = None
|
||||
elif "error" in decoded_token:
|
||||
return jsonify(decoded_token), 401
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import uuid
|
||||
|
||||
from jose import jwt
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
def handle_auth(request, data={}):
|
||||
if settings.AUTH_TYPE == "simple_jwt":
|
||||
if settings.AUTH_TYPE in ["simple_jwt", "session_jwt"]:
|
||||
jwt_token = request.headers.get("Authorization")
|
||||
if not jwt_token:
|
||||
return {"message": "Missing Authorization header"}
|
||||
return None
|
||||
|
||||
jwt_token = jwt_token.replace("Bearer ", "")
|
||||
|
||||
@@ -22,18 +20,9 @@ def handle_auth(request, data={}):
|
||||
)
|
||||
return decoded_token
|
||||
except Exception as e:
|
||||
return {"message": f"Authentication error: {str(e)}"}
|
||||
return {
|
||||
"message": f"Authentication error: {str(e)}",
|
||||
"error": "invalid_token",
|
||||
}
|
||||
else:
|
||||
return {"sub": "local"}
|
||||
|
||||
|
||||
def get_or_create_user_id():
|
||||
try:
|
||||
with open(settings.USER_ID_FILE, "r") as f:
|
||||
user_id = f.read().strip()
|
||||
return user_id
|
||||
except FileNotFoundError:
|
||||
user_id = str(uuid.uuid4())
|
||||
with open(settings.USER_ID_FILE, "w") as f:
|
||||
f.write(user_id)
|
||||
return user_id
|
||||
|
||||
@@ -100,7 +100,6 @@ class Settings(BaseSettings):
|
||||
FLASK_DEBUG_MODE: bool = False
|
||||
|
||||
JWT_SECRET_KEY: str = ""
|
||||
USER_ID_FILE: str = os.path.join(current_dir, "user_id.txt")
|
||||
|
||||
|
||||
path = Path(__file__).parent.parent.absolute()
|
||||
|
||||
@@ -1,15 +1,30 @@
|
||||
import { Routes, Route } from 'react-router-dom';
|
||||
import Navigation from './Navigation';
|
||||
import Conversation from './conversation/Conversation';
|
||||
import About from './About';
|
||||
import PageNotFound from './PageNotFound';
|
||||
import { useMediaQuery } from './hooks';
|
||||
import { useState } from 'react';
|
||||
import Setting from './settings';
|
||||
import './locale/i18n';
|
||||
import { Outlet } from 'react-router-dom';
|
||||
|
||||
import { useState } from 'react';
|
||||
import { Outlet, Route, Routes } from 'react-router-dom';
|
||||
|
||||
import About from './About';
|
||||
import Spinner from './components/Spinner';
|
||||
import Conversation from './conversation/Conversation';
|
||||
import { SharedConversation } from './conversation/SharedConversation';
|
||||
import { useDarkTheme } from './hooks';
|
||||
import { useDarkTheme, useMediaQuery } from './hooks';
|
||||
import useTokenAuth from './hooks/useTokenAuth';
|
||||
import Navigation from './Navigation';
|
||||
import PageNotFound from './PageNotFound';
|
||||
import Setting from './settings';
|
||||
|
||||
function AuthWrapper({ children }: { children: React.ReactNode }) {
|
||||
const { isAuthLoading } = useTokenAuth();
|
||||
|
||||
if (isAuthLoading) {
|
||||
return (
|
||||
<div className="h-screen flex items-center justify-center">
|
||||
<Spinner />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
return <>{children}</>;
|
||||
}
|
||||
|
||||
function MainLayout() {
|
||||
const { isMobile } = useMediaQuery();
|
||||
@@ -39,7 +54,13 @@ export default function App() {
|
||||
return (
|
||||
<div className="h-full relative overflow-auto">
|
||||
<Routes>
|
||||
<Route element={<MainLayout />}>
|
||||
<Route
|
||||
element={
|
||||
<AuthWrapper>
|
||||
<MainLayout />
|
||||
</AuthWrapper>
|
||||
}
|
||||
>
|
||||
<Route index element={<Conversation />} />
|
||||
<Route path="/about" element={<About />} />
|
||||
<Route path="/settings" element={<Setting />} />
|
||||
|
||||
@@ -28,6 +28,7 @@ import {
|
||||
import ConversationTile from './conversation/ConversationTile';
|
||||
import { useDarkTheme, useMediaQuery } from './hooks';
|
||||
import useDefaultDocument from './hooks/useDefaultDocument';
|
||||
import useTokenAuth from './hooks/useTokenAuth';
|
||||
import DeleteConvModal from './modals/DeleteConvModal';
|
||||
import JWTModal from './modals/JWTModal';
|
||||
import { ActiveState, Doc } from './models/misc';
|
||||
@@ -72,10 +73,10 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
const { t } = useTranslation();
|
||||
const isApiKeySet = useSelector(selectApiKeyStatus);
|
||||
|
||||
const { showTokenModal, handleTokenSubmit } = useTokenAuth();
|
||||
|
||||
const [uploadModalState, setUploadModalState] =
|
||||
useState<ActiveState>('INACTIVE');
|
||||
const [authKeyModalState, setAuthKeyModalState] =
|
||||
useState<ActiveState>('INACTIVE');
|
||||
|
||||
const navRef = useRef(null);
|
||||
|
||||
@@ -204,15 +205,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
setNavOpen(!isMobile);
|
||||
}, [isMobile]);
|
||||
|
||||
useEffect(() => {
|
||||
const authToken = localStorage.getItem('authToken');
|
||||
if (!authToken) {
|
||||
setAuthKeyModalState('ACTIVE');
|
||||
}
|
||||
}, []);
|
||||
|
||||
useDefaultDocument();
|
||||
|
||||
return (
|
||||
<>
|
||||
{!navOpen && (
|
||||
@@ -485,8 +478,8 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
></Upload>
|
||||
)}
|
||||
<JWTModal
|
||||
modalState={authKeyModalState}
|
||||
setModalState={setAuthKeyModalState}
|
||||
modalState={showTokenModal ? 'ACTIVE' : 'INACTIVE'}
|
||||
handleTokenSubmit={handleTokenSubmit}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
const endpoints = {
|
||||
USER: {
|
||||
CONFIG: '/api/config',
|
||||
NEW_TOKEN: '/api/generate_token',
|
||||
DOCS: '/api/sources',
|
||||
DOCS_CHECK: '/api/docs_check',
|
||||
DOCS_PAGINATED: '/api/sources/paginated',
|
||||
|
||||
@@ -2,6 +2,9 @@ import apiClient from '../client';
|
||||
import endpoints from '../endpoints';
|
||||
|
||||
const userService = {
|
||||
getConfig: (): Promise<any> => apiClient.get(endpoints.USER.CONFIG, null),
|
||||
getNewToken: (): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.NEW_TOKEN, null),
|
||||
getDocs: (token: string | null): Promise<any> =>
|
||||
apiClient.get(`${endpoints.USER.DOCS}`, token),
|
||||
getDocsWithPagination: (query: string, token: string | null): Promise<any> =>
|
||||
|
||||
55
frontend/src/hooks/useTokenAuth.ts
Normal file
55
frontend/src/hooks/useTokenAuth.ts
Normal file
@@ -0,0 +1,55 @@
|
||||
import { useEffect, useRef, useState } from 'react';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
|
||||
import userService from '../api/services/userService';
|
||||
import { selectToken, setToken } from '../preferences/preferenceSlice';
|
||||
|
||||
export default function useAuth() {
|
||||
const dispatch = useDispatch();
|
||||
const token = useSelector(selectToken);
|
||||
const [authType, setAuthType] = useState(null);
|
||||
const [showTokenModal, setShowTokenModal] = useState(false);
|
||||
const [isAuthLoading, setIsAuthLoading] = useState(true);
|
||||
const isGeneratingToken = useRef(false);
|
||||
|
||||
const generateNewToken = async () => {
|
||||
if (isGeneratingToken.current) return;
|
||||
isGeneratingToken.current = true;
|
||||
const response = await userService.getNewToken();
|
||||
const { token: newToken } = await response.json();
|
||||
localStorage.setItem('authToken', newToken);
|
||||
dispatch(setToken(newToken));
|
||||
setIsAuthLoading(false);
|
||||
return newToken;
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const initializeAuth = async () => {
|
||||
try {
|
||||
const configRes = await userService.getConfig();
|
||||
const config = await configRes.json();
|
||||
setAuthType(config.auth_type);
|
||||
|
||||
if (config.auth_type === 'session_jwt' && !token) {
|
||||
await generateNewToken();
|
||||
} else if (config.auth_type === 'simple_jwt' && !token) {
|
||||
setShowTokenModal(true);
|
||||
setIsAuthLoading(false);
|
||||
} else {
|
||||
setIsAuthLoading(false);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Auth initialization failed:', error);
|
||||
setIsAuthLoading(false);
|
||||
}
|
||||
};
|
||||
initializeAuth();
|
||||
}, []);
|
||||
|
||||
const handleTokenSubmit = (enteredToken: string) => {
|
||||
localStorage.setItem('authToken', enteredToken);
|
||||
dispatch(setToken(enteredToken));
|
||||
setShowTokenModal(false);
|
||||
};
|
||||
return { authType, showTokenModal, isAuthLoading, token, handleTokenSubmit };
|
||||
}
|
||||
@@ -3,30 +3,23 @@ import { useDispatch } from 'react-redux';
|
||||
|
||||
import Input from '../components/Input';
|
||||
import { ActiveState } from '../models/misc';
|
||||
import { setToken } from '../preferences/preferenceSlice';
|
||||
import WrapperModal from './WrapperModal';
|
||||
|
||||
type JWTModalProps = {
|
||||
modalState: ActiveState;
|
||||
setModalState: (state: ActiveState) => void;
|
||||
handleTokenSubmit: (enteredToken: string) => void;
|
||||
};
|
||||
|
||||
export default function JWTModal({ modalState, setModalState }: JWTModalProps) {
|
||||
const dispatch = useDispatch();
|
||||
export default function JWTModal({
|
||||
modalState,
|
||||
handleTokenSubmit,
|
||||
}: JWTModalProps) {
|
||||
const [jwtToken, setJwtToken] = useState<string>('');
|
||||
|
||||
const handleSaveToken = () => {
|
||||
if (jwtToken) {
|
||||
localStorage.setItem('authToken', jwtToken);
|
||||
dispatch(setToken(jwtToken));
|
||||
setModalState('INACTIVE');
|
||||
}
|
||||
};
|
||||
|
||||
if (modalState !== 'ACTIVE') return null;
|
||||
|
||||
return (
|
||||
<WrapperModal close={() => setModalState('INACTIVE')} className="p-4">
|
||||
<WrapperModal className="p-4" isPerformingTask={true} close={() => {}}>
|
||||
<div className="mb-6">
|
||||
<span className="text-lg text-jet dark:text-bright-gray">
|
||||
Add JWT Token
|
||||
@@ -44,7 +37,7 @@ export default function JWTModal({ modalState, setModalState }: JWTModalProps) {
|
||||
</div>
|
||||
<button
|
||||
disabled={jwtToken.length === 0}
|
||||
onClick={handleSaveToken}
|
||||
onClick={handleTokenSubmit.bind(null, jwtToken)}
|
||||
className="float-right mt-4 rounded-full bg-purple-30 px-5 py-2 text-sm text-white hover:bg-[#6F3FD1] disabled:opacity-50"
|
||||
>
|
||||
Save Token
|
||||
|
||||
Reference in New Issue
Block a user