feat: pass decoded_token to llm and retrievers

This commit is contained in:
Siddhant Rai
2025-03-18 23:46:02 +05:30
parent f4ab85a2bb
commit ab95d90284
9 changed files with 75 additions and 25 deletions

View File

@@ -9,10 +9,21 @@ from application.llm.llm_creator import LLMCreator
class BaseAgent:
def __init__(self, endpoint, llm_name, gpt_model, api_key, user_api_key=None):
def __init__(
self,
endpoint,
llm_name,
gpt_model,
api_key,
user_api_key=None,
decoded_token=None,
):
self.endpoint = endpoint
self.llm = LLMCreator.create_llm(
llm_name, api_key=api_key, user_api_key=user_api_key
llm_name,
api_key=api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
self.llm_handler = get_llm_handler(llm_name)
self.gpt_model = gpt_model

View File

@@ -17,8 +17,12 @@ class ClassicAgent(BaseAgent):
user_api_key=None,
prompt="",
chat_history=None,
decoded_token=None,
):
super().__init__(endpoint, llm_name, gpt_model, api_key, user_api_key)
super().__init__(
endpoint, llm_name, gpt_model, api_key, user_api_key, decoded_token
)
self.user = decoded_token.get("sub")
self.prompt = prompt
self.chat_history = chat_history if chat_history is not None else []
@@ -73,7 +77,7 @@ class ClassicAgent(BaseAgent):
)
messages_combine.append({"role": "user", "content": query})
tools_dict = self._get_user_tools()
tools_dict = self._get_user_tools(self.user)
self._prepare_tools(tools_dict)
resp = self._llm_gen(messages_combine, log_context)

View File

@@ -264,7 +264,10 @@ def complete_stream(
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
settings.LLM_NAME,
api_key=settings.API_KEY,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
if should_save_conversation:
@@ -420,6 +423,7 @@ class Stream(Resource):
user_api_key=user_api_key,
prompt=prompt,
chat_history=history,
decoded_token=decoded_token,
)
retriever = RetrieverCreator.create_retriever(
@@ -431,6 +435,7 @@ class Stream(Resource):
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
return Response(
@@ -565,6 +570,7 @@ class Answer(Resource):
user_api_key=user_api_key,
prompt=prompt,
chat_history=history,
decoded_token=decoded_token,
)
retriever = RetrieverCreator.create_retriever(
@@ -576,6 +582,7 @@ class Answer(Resource):
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
response_full = ""
@@ -623,7 +630,10 @@ class Answer(Resource):
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
settings.LLM_NAME,
api_key=settings.API_KEY,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
result = {"answer": response_full, "sources": source_log_docs}
@@ -743,6 +753,7 @@ class Search(Resource):
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
docs = retriever.search(question)

View File

@@ -5,7 +5,8 @@ from application.usage import gen_token_usage, stream_token_usage
class BaseLLM(ABC):
def __init__(self):
def __init__(self, decoded_token):
self.decoded_token = decoded_token
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
def _apply_decorator(self, method, decorators, *args, **kwargs):

View File

@@ -9,6 +9,7 @@ from application.llm.premai import PremAILLM
from application.llm.google_ai import GoogleLLM
from application.llm.novita import NovitaLLM
class LLMCreator:
llms = {
"openai": OpenAILLM,
@@ -21,12 +22,14 @@ class LLMCreator:
"premai": PremAILLM,
"groq": GroqLLM,
"google": GoogleLLM,
"novita": NovitaLLM
"novita": NovitaLLM,
}
@classmethod
def create_llm(cls, type, api_key, user_api_key, *args, **kwargs):
def create_llm(cls, type, api_key, user_api_key, decoded_token, *args, **kwargs):
llm_class = cls.llms.get(type.lower())
if not llm_class:
raise ValueError(f"No LLM class found for type {type}")
return llm_class(api_key, user_api_key, *args, **kwargs)
return llm_class(
api_key, user_api_key, decoded_token=decoded_token, *args, **kwargs
)

View File

@@ -17,6 +17,7 @@ class BraveRetSearch(BaseRetriever):
token_limit=150,
gpt_model="docsgpt",
user_api_key=None,
decoded_token=None,
):
self.question = question
self.source = source
@@ -35,6 +36,7 @@ class BraveRetSearch(BaseRetriever):
)
)
self.user_api_key = user_api_key
self.decoded_token = decoded_token
def _get_data(self):
if self.chunks == 0:
@@ -81,7 +83,10 @@ class BraveRetSearch(BaseRetriever):
messages_combine.append({"role": "user", "content": self.question})
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
settings.LLM_NAME,
api_key=settings.API_KEY,
user_api_key=self.user_api_key,
decoded_token=self.decoded_token,
)
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
@@ -100,5 +105,5 @@ class BraveRetSearch(BaseRetriever):
"chunks": self.chunks,
"token_limit": self.token_limit,
"gpt_model": self.gpt_model,
"user_api_key": self.user_api_key
"user_api_key": self.user_api_key,
}

View File

@@ -17,6 +17,7 @@ class ClassicRAG(BaseRetriever):
user_api_key=None,
llm_name=settings.LLM_NAME,
api_key=settings.API_KEY,
decoded_token=None,
):
self.original_question = ""
self.chat_history = chat_history if chat_history is not None else []
@@ -37,10 +38,14 @@ class ClassicRAG(BaseRetriever):
self.llm_name = llm_name
self.api_key = api_key
self.llm = LLMCreator.create_llm(
self.llm_name, api_key=self.api_key, user_api_key=self.user_api_key
self.llm_name,
api_key=self.api_key,
user_api_key=self.user_api_key,
decoded_token=decoded_token,
)
self.question = self._rephrase_query()
self.vectorstore = source["active_docs"] if "active_docs" in source else None
self.decoded_token = decoded_token
def _rephrase_query(self):
if (

View File

@@ -17,6 +17,7 @@ class DuckDuckSearch(BaseRetriever):
token_limit=150,
gpt_model="docsgpt",
user_api_key=None,
decoded_token=None,
):
self.question = question
self.source = source
@@ -35,6 +36,7 @@ class DuckDuckSearch(BaseRetriever):
)
)
self.user_api_key = user_api_key
self.decoded_token = decoded_token
def _parse_lang_string(self, input_string):
result = []
@@ -88,17 +90,20 @@ class DuckDuckSearch(BaseRetriever):
for doc in docs:
yield {"source": doc}
if len(self.chat_history) > 0:
if len(self.chat_history) > 0:
for i in self.chat_history:
if "prompt" in i and "response" in i:
messages_combine.append({"role": "user", "content": i["prompt"]})
messages_combine.append(
{"role": "assistant", "content": i["response"]}
)
if "prompt" in i and "response" in i:
messages_combine.append({"role": "user", "content": i["prompt"]})
messages_combine.append(
{"role": "assistant", "content": i["response"]}
)
messages_combine.append({"role": "user", "content": self.question})
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
settings.LLM_NAME,
api_key=settings.API_KEY,
user_api_key=self.user_api_key,
decoded_token=self.decoded_token,
)
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
@@ -107,7 +112,7 @@ class DuckDuckSearch(BaseRetriever):
def search(self):
return self._get_data()
def get_params(self):
return {
"question": self.question,
@@ -117,5 +122,5 @@ class DuckDuckSearch(BaseRetriever):
"chunks": self.chunks,
"token_limit": self.token_limit,
"gpt_model": self.gpt_model,
"user_api_key": self.user_api_key
"user_api_key": self.user_api_key,
}

View File

@@ -9,10 +9,15 @@ db = mongo["docsgpt"]
usage_collection = db["token_usage"]
def update_token_usage(user_api_key, token_usage):
def update_token_usage(decoded_token, user_api_key, token_usage):
if "pytest" in sys.modules:
return
if decoded_token:
user_id = decoded_token["sub"]
else:
user_id = None
usage_data = {
"user_id": user_id,
"api_key": user_api_key,
"prompt_tokens": token_usage["prompt_tokens"],
"generated_tokens": token_usage["generated_tokens"],
@@ -35,7 +40,7 @@ def gen_token_usage(func):
self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(
result
)
update_token_usage(self.user_api_key, self.token_usage)
update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
return result
return wrapper
@@ -54,6 +59,6 @@ def stream_token_usage(func):
yield r
for line in batch:
self.token_usage["generated_tokens"] += num_tokens_from_string(line)
update_token_usage(self.user_api_key, self.token_usage)
update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
return wrapper