diff --git a/application/agents/base.py b/application/agents/base.py
index 7e36c991..d0f972a9 100644
--- a/application/agents/base.py
+++ b/application/agents/base.py
@@ -9,10 +9,21 @@ from application.llm.llm_creator import LLMCreator
 
 
 class BaseAgent:
-    def __init__(self, endpoint, llm_name, gpt_model, api_key, user_api_key=None):
+    def __init__(
+        self,
+        endpoint,
+        llm_name,
+        gpt_model,
+        api_key,
+        user_api_key=None,
+        decoded_token=None,
+    ):
         self.endpoint = endpoint
         self.llm = LLMCreator.create_llm(
-            llm_name, api_key=api_key, user_api_key=user_api_key
+            llm_name,
+            api_key=api_key,
+            user_api_key=user_api_key,
+            decoded_token=decoded_token,
         )
         self.llm_handler = get_llm_handler(llm_name)
         self.gpt_model = gpt_model
diff --git a/application/agents/classic_agent.py b/application/agents/classic_agent.py
index 8848c6f6..2752c833 100644
--- a/application/agents/classic_agent.py
+++ b/application/agents/classic_agent.py
@@ -17,8 +17,12 @@ class ClassicAgent(BaseAgent):
         user_api_key=None,
         prompt="",
         chat_history=None,
+        decoded_token=None,
     ):
-        super().__init__(endpoint, llm_name, gpt_model, api_key, user_api_key)
+        super().__init__(
+            endpoint, llm_name, gpt_model, api_key, user_api_key, decoded_token
+        )
+        self.user = decoded_token.get("sub")
         self.prompt = prompt
         self.chat_history = chat_history if chat_history is not None else []
 
@@ -73,7 +77,7 @@ class ClassicAgent(BaseAgent):
             )
         messages_combine.append({"role": "user", "content": query})
 
-        tools_dict = self._get_user_tools()
+        tools_dict = self._get_user_tools(self.user)
         self._prepare_tools(tools_dict)
 
         resp = self._llm_gen(messages_combine, log_context)
diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py
index faab9e08..34081784 100644
--- a/application/api/answer/routes.py
+++ b/application/api/answer/routes.py
@@ -264,7 +264,10 @@ def complete_stream(
             doc["source"] = "None"
 
     llm = LLMCreator.create_llm(
-        settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
+        settings.LLM_NAME,
+        api_key=settings.API_KEY,
+        user_api_key=user_api_key,
+        decoded_token=decoded_token,
     )
 
     if should_save_conversation:
@@ -420,6 +423,7 @@ class Stream(Resource):
                 user_api_key=user_api_key,
                 prompt=prompt,
                 chat_history=history,
+                decoded_token=decoded_token,
             )
 
             retriever = RetrieverCreator.create_retriever(
@@ -431,6 +435,7 @@
                 token_limit=token_limit,
                 gpt_model=gpt_model,
                 user_api_key=user_api_key,
+                decoded_token=decoded_token,
             )
 
             return Response(
@@ -565,6 +570,7 @@ class Answer(Resource):
                 user_api_key=user_api_key,
                 prompt=prompt,
                 chat_history=history,
+                decoded_token=decoded_token,
             )
 
             retriever = RetrieverCreator.create_retriever(
@@ -576,6 +582,7 @@
                 token_limit=token_limit,
                 gpt_model=gpt_model,
                 user_api_key=user_api_key,
+                decoded_token=decoded_token,
             )
 
             response_full = ""
@@ -623,7 +630,10 @@
                     doc["source"] = "None"
 
             llm = LLMCreator.create_llm(
-                settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
+                settings.LLM_NAME,
+                api_key=settings.API_KEY,
+                user_api_key=user_api_key,
+                decoded_token=decoded_token,
             )
 
             result = {"answer": response_full, "sources": source_log_docs}
@@ -743,6 +753,7 @@
                 token_limit=token_limit,
                 gpt_model=gpt_model,
                 user_api_key=user_api_key,
+                decoded_token=decoded_token,
             )
 
             docs = retriever.search(question)
diff --git a/application/llm/base.py b/application/llm/base.py
index e687e567..39c69499 100644
--- a/application/llm/base.py
+++ b/application/llm/base.py
@@ -5,7 +5,8 @@ from application.usage import gen_token_usage, stream_token_usage
 
 
 class BaseLLM(ABC):
-    def __init__(self):
+    def __init__(self, decoded_token):
+        self.decoded_token = decoded_token
         self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
 
     def _apply_decorator(self, method, decorators, *args, **kwargs):
diff --git a/application/llm/llm_creator.py b/application/llm/llm_creator.py
index 9f1305ba..3ed23854 100644
--- a/application/llm/llm_creator.py
+++ b/application/llm/llm_creator.py
@@ -9,6 +9,7 @@ from application.llm.premai import PremAILLM
 from application.llm.google_ai import GoogleLLM
 from application.llm.novita import NovitaLLM
 
+
 class LLMCreator:
     llms = {
         "openai": OpenAILLM,
@@ -21,12 +22,14 @@ class LLMCreator:
         "premai": PremAILLM,
         "groq": GroqLLM,
         "google": GoogleLLM,
-        "novita": NovitaLLM
+        "novita": NovitaLLM,
     }
 
     @classmethod
-    def create_llm(cls, type, api_key, user_api_key, *args, **kwargs):
+    def create_llm(cls, type, api_key, user_api_key, decoded_token, *args, **kwargs):
         llm_class = cls.llms.get(type.lower())
         if not llm_class:
             raise ValueError(f"No LLM class found for type {type}")
-        return llm_class(api_key, user_api_key, *args, **kwargs)
+        return llm_class(
+            api_key, user_api_key, decoded_token=decoded_token, *args, **kwargs
+        )
diff --git a/application/retriever/brave_search.py b/application/retriever/brave_search.py
index 08b16bc0..ed490734 100644
--- a/application/retriever/brave_search.py
+++ b/application/retriever/brave_search.py
@@ -17,6 +17,7 @@
         token_limit=150,
         gpt_model="docsgpt",
         user_api_key=None,
+        decoded_token=None,
     ):
         self.question = question
         self.source = source
@@ -35,6 +36,7 @@
             )
         )
         self.user_api_key = user_api_key
+        self.decoded_token = decoded_token
 
     def _get_data(self):
         if self.chunks == 0:
@@ -81,7 +83,10 @@
         messages_combine.append({"role": "user", "content": self.question})
 
         llm = LLMCreator.create_llm(
-            settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
+            settings.LLM_NAME,
+            api_key=settings.API_KEY,
+            user_api_key=self.user_api_key,
+            decoded_token=self.decoded_token,
         )
 
         completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
@@ -100,5 +105,5 @@
             "chunks": self.chunks,
             "token_limit": self.token_limit,
             "gpt_model": self.gpt_model,
-            "user_api_key": self.user_api_key
+            "user_api_key": self.user_api_key,
         }
diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py
index 03f17f44..08771337 100644
--- a/application/retriever/classic_rag.py
+++ b/application/retriever/classic_rag.py
@@ -17,6 +17,7 @@
         user_api_key=None,
         llm_name=settings.LLM_NAME,
         api_key=settings.API_KEY,
+        decoded_token=None,
     ):
         self.original_question = ""
         self.chat_history = chat_history if chat_history is not None else []
@@ -37,10 +38,14 @@
         self.llm_name = llm_name
         self.api_key = api_key
         self.llm = LLMCreator.create_llm(
-            self.llm_name, api_key=self.api_key, user_api_key=self.user_api_key
+            self.llm_name,
+            api_key=self.api_key,
+            user_api_key=self.user_api_key,
+            decoded_token=decoded_token,
         )
         self.question = self._rephrase_query()
         self.vectorstore = source["active_docs"] if "active_docs" in source else None
+        self.decoded_token = decoded_token
 
     def _rephrase_query(self):
         if (
diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py
index c6386410..9ce73995 100644
--- a/application/retriever/duckduck_search.py
+++ b/application/retriever/duckduck_search.py
@@ -17,6 +17,7 @@
         token_limit=150,
         gpt_model="docsgpt",
         user_api_key=None,
+        decoded_token=None,
    ):
         self.question = question
         self.source = source
@@ -35,6 +36,7 @@
             )
         )
         self.user_api_key = user_api_key
+        self.decoded_token = decoded_token
 
     def _parse_lang_string(self, input_string):
         result = []
@@ -88,17 +90,20 @@
         for doc in docs:
             yield {"source": doc}
 
-        if len(self.chat_history) > 0:
+        if len(self.chat_history) > 0:
             for i in self.chat_history:
-                if "prompt" in i and "response" in i:
-                    messages_combine.append({"role": "user", "content": i["prompt"]})
-                    messages_combine.append(
-                        {"role": "assistant", "content": i["response"]}
-                    )
+                if "prompt" in i and "response" in i:
+                    messages_combine.append({"role": "user", "content": i["prompt"]})
+                    messages_combine.append(
+                        {"role": "assistant", "content": i["response"]}
+                    )
         messages_combine.append({"role": "user", "content": self.question})
 
         llm = LLMCreator.create_llm(
-            settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
+            settings.LLM_NAME,
+            api_key=settings.API_KEY,
+            user_api_key=self.user_api_key,
+            decoded_token=self.decoded_token,
         )
 
         completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
@@ -107,7 +112,7 @@
 
     def search(self):
         return self._get_data()
-    
+
     def get_params(self):
         return {
             "question": self.question,
@@ -117,5 +122,5 @@
             "chunks": self.chunks,
             "token_limit": self.token_limit,
             "gpt_model": self.gpt_model,
-            "user_api_key": self.user_api_key
+            "user_api_key": self.user_api_key,
         }
diff --git a/application/usage.py b/application/usage.py
index a18a3848..85328c1f 100644
--- a/application/usage.py
+++ b/application/usage.py
@@ -9,10 +9,15 @@ db = mongo["docsgpt"]
 usage_collection = db["token_usage"]
 
 
-def update_token_usage(user_api_key, token_usage):
+def update_token_usage(decoded_token, user_api_key, token_usage):
     if "pytest" in sys.modules:
         return
+    if decoded_token:
+        user_id = decoded_token["sub"]
+    else:
+        user_id = None
     usage_data = {
+        "user_id": user_id,
         "api_key": user_api_key,
         "prompt_tokens": token_usage["prompt_tokens"],
         "generated_tokens": token_usage["generated_tokens"],
@@ -35,7 +40,7 @@
         self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(
             result
         )
-        update_token_usage(self.user_api_key, self.token_usage)
+        update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
         return result
 
     return wrapper
@@ -54,6 +59,6 @@
             yield r
         for line in batch:
             self.token_usage["generated_tokens"] += num_tokens_from_string(line)
-        update_token_usage(self.user_api_key, self.token_usage)
+        update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
 
     return wrapper