mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 16:43:16 +00:00
feat: pass decoded_token to llm and retrievers
This commit is contained in:
@@ -9,10 +9,21 @@ from application.llm.llm_creator import LLMCreator
|
||||
|
||||
|
||||
class BaseAgent:
|
||||
def __init__(self, endpoint, llm_name, gpt_model, api_key, user_api_key=None):
|
||||
def __init__(
|
||||
self,
|
||||
endpoint,
|
||||
llm_name,
|
||||
gpt_model,
|
||||
api_key,
|
||||
user_api_key=None,
|
||||
decoded_token=None,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.llm = LLMCreator.create_llm(
|
||||
llm_name, api_key=api_key, user_api_key=user_api_key
|
||||
llm_name,
|
||||
api_key=api_key,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
self.llm_handler = get_llm_handler(llm_name)
|
||||
self.gpt_model = gpt_model
|
||||
|
||||
@@ -17,8 +17,12 @@ class ClassicAgent(BaseAgent):
|
||||
user_api_key=None,
|
||||
prompt="",
|
||||
chat_history=None,
|
||||
decoded_token=None,
|
||||
):
|
||||
super().__init__(endpoint, llm_name, gpt_model, api_key, user_api_key)
|
||||
super().__init__(
|
||||
endpoint, llm_name, gpt_model, api_key, user_api_key, decoded_token
|
||||
)
|
||||
self.user = decoded_token.get("sub")
|
||||
self.prompt = prompt
|
||||
self.chat_history = chat_history if chat_history is not None else []
|
||||
|
||||
@@ -73,7 +77,7 @@ class ClassicAgent(BaseAgent):
|
||||
)
|
||||
messages_combine.append({"role": "user", "content": query})
|
||||
|
||||
tools_dict = self._get_user_tools()
|
||||
tools_dict = self._get_user_tools(self.user)
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
resp = self._llm_gen(messages_combine, log_context)
|
||||
|
||||
@@ -264,7 +264,10 @@ def complete_stream(
|
||||
doc["source"] = "None"
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
|
||||
settings.LLM_NAME,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
if should_save_conversation:
|
||||
@@ -420,6 +423,7 @@ class Stream(Resource):
|
||||
user_api_key=user_api_key,
|
||||
prompt=prompt,
|
||||
chat_history=history,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
@@ -431,6 +435,7 @@ class Stream(Resource):
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
return Response(
|
||||
@@ -565,6 +570,7 @@ class Answer(Resource):
|
||||
user_api_key=user_api_key,
|
||||
prompt=prompt,
|
||||
chat_history=history,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
@@ -576,6 +582,7 @@ class Answer(Resource):
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
response_full = ""
|
||||
@@ -623,7 +630,10 @@ class Answer(Resource):
|
||||
doc["source"] = "None"
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
|
||||
settings.LLM_NAME,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
result = {"answer": response_full, "sources": source_log_docs}
|
||||
@@ -743,6 +753,7 @@ class Search(Resource):
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
docs = retriever.search(question)
|
||||
|
||||
@@ -5,7 +5,8 @@ from application.usage import gen_token_usage, stream_token_usage
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
def __init__(self):
|
||||
def __init__(self, decoded_token):
|
||||
self.decoded_token = decoded_token
|
||||
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
|
||||
def _apply_decorator(self, method, decorators, *args, **kwargs):
|
||||
|
||||
@@ -9,6 +9,7 @@ from application.llm.premai import PremAILLM
|
||||
from application.llm.google_ai import GoogleLLM
|
||||
from application.llm.novita import NovitaLLM
|
||||
|
||||
|
||||
class LLMCreator:
|
||||
llms = {
|
||||
"openai": OpenAILLM,
|
||||
@@ -21,12 +22,14 @@ class LLMCreator:
|
||||
"premai": PremAILLM,
|
||||
"groq": GroqLLM,
|
||||
"google": GoogleLLM,
|
||||
"novita": NovitaLLM
|
||||
"novita": NovitaLLM,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create_llm(cls, type, api_key, user_api_key, *args, **kwargs):
|
||||
def create_llm(cls, type, api_key, user_api_key, decoded_token, *args, **kwargs):
|
||||
llm_class = cls.llms.get(type.lower())
|
||||
if not llm_class:
|
||||
raise ValueError(f"No LLM class found for type {type}")
|
||||
return llm_class(api_key, user_api_key, *args, **kwargs)
|
||||
return llm_class(
|
||||
api_key, user_api_key, decoded_token=decoded_token, *args, **kwargs
|
||||
)
|
||||
|
||||
@@ -17,6 +17,7 @@ class BraveRetSearch(BaseRetriever):
|
||||
token_limit=150,
|
||||
gpt_model="docsgpt",
|
||||
user_api_key=None,
|
||||
decoded_token=None,
|
||||
):
|
||||
self.question = question
|
||||
self.source = source
|
||||
@@ -35,6 +36,7 @@ class BraveRetSearch(BaseRetriever):
|
||||
)
|
||||
)
|
||||
self.user_api_key = user_api_key
|
||||
self.decoded_token = decoded_token
|
||||
|
||||
def _get_data(self):
|
||||
if self.chunks == 0:
|
||||
@@ -81,7 +83,10 @@ class BraveRetSearch(BaseRetriever):
|
||||
messages_combine.append({"role": "user", "content": self.question})
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
|
||||
settings.LLM_NAME,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=self.user_api_key,
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
|
||||
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
|
||||
@@ -100,5 +105,5 @@ class BraveRetSearch(BaseRetriever):
|
||||
"chunks": self.chunks,
|
||||
"token_limit": self.token_limit,
|
||||
"gpt_model": self.gpt_model,
|
||||
"user_api_key": self.user_api_key
|
||||
"user_api_key": self.user_api_key,
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ class ClassicRAG(BaseRetriever):
|
||||
user_api_key=None,
|
||||
llm_name=settings.LLM_NAME,
|
||||
api_key=settings.API_KEY,
|
||||
decoded_token=None,
|
||||
):
|
||||
self.original_question = ""
|
||||
self.chat_history = chat_history if chat_history is not None else []
|
||||
@@ -37,10 +38,14 @@ class ClassicRAG(BaseRetriever):
|
||||
self.llm_name = llm_name
|
||||
self.api_key = api_key
|
||||
self.llm = LLMCreator.create_llm(
|
||||
self.llm_name, api_key=self.api_key, user_api_key=self.user_api_key
|
||||
self.llm_name,
|
||||
api_key=self.api_key,
|
||||
user_api_key=self.user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
self.question = self._rephrase_query()
|
||||
self.vectorstore = source["active_docs"] if "active_docs" in source else None
|
||||
self.decoded_token = decoded_token
|
||||
|
||||
def _rephrase_query(self):
|
||||
if (
|
||||
|
||||
@@ -17,6 +17,7 @@ class DuckDuckSearch(BaseRetriever):
|
||||
token_limit=150,
|
||||
gpt_model="docsgpt",
|
||||
user_api_key=None,
|
||||
decoded_token=None,
|
||||
):
|
||||
self.question = question
|
||||
self.source = source
|
||||
@@ -35,6 +36,7 @@ class DuckDuckSearch(BaseRetriever):
|
||||
)
|
||||
)
|
||||
self.user_api_key = user_api_key
|
||||
self.decoded_token = decoded_token
|
||||
|
||||
def _parse_lang_string(self, input_string):
|
||||
result = []
|
||||
@@ -88,17 +90,20 @@ class DuckDuckSearch(BaseRetriever):
|
||||
for doc in docs:
|
||||
yield {"source": doc}
|
||||
|
||||
if len(self.chat_history) > 0:
|
||||
if len(self.chat_history) > 0:
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
||||
messages_combine.append(
|
||||
{"role": "assistant", "content": i["response"]}
|
||||
)
|
||||
if "prompt" in i and "response" in i:
|
||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
||||
messages_combine.append(
|
||||
{"role": "assistant", "content": i["response"]}
|
||||
)
|
||||
messages_combine.append({"role": "user", "content": self.question})
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
|
||||
settings.LLM_NAME,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=self.user_api_key,
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
|
||||
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
|
||||
@@ -107,7 +112,7 @@ class DuckDuckSearch(BaseRetriever):
|
||||
|
||||
def search(self):
|
||||
return self._get_data()
|
||||
|
||||
|
||||
def get_params(self):
|
||||
return {
|
||||
"question": self.question,
|
||||
@@ -117,5 +122,5 @@ class DuckDuckSearch(BaseRetriever):
|
||||
"chunks": self.chunks,
|
||||
"token_limit": self.token_limit,
|
||||
"gpt_model": self.gpt_model,
|
||||
"user_api_key": self.user_api_key
|
||||
"user_api_key": self.user_api_key,
|
||||
}
|
||||
|
||||
@@ -9,10 +9,15 @@ db = mongo["docsgpt"]
|
||||
usage_collection = db["token_usage"]
|
||||
|
||||
|
||||
def update_token_usage(user_api_key, token_usage):
|
||||
def update_token_usage(decoded_token, user_api_key, token_usage):
|
||||
if "pytest" in sys.modules:
|
||||
return
|
||||
if decoded_token:
|
||||
user_id = decoded_token["sub"]
|
||||
else:
|
||||
user_id = None
|
||||
usage_data = {
|
||||
"user_id": user_id,
|
||||
"api_key": user_api_key,
|
||||
"prompt_tokens": token_usage["prompt_tokens"],
|
||||
"generated_tokens": token_usage["generated_tokens"],
|
||||
@@ -35,7 +40,7 @@ def gen_token_usage(func):
|
||||
self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(
|
||||
result
|
||||
)
|
||||
update_token_usage(self.user_api_key, self.token_usage)
|
||||
update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
@@ -54,6 +59,6 @@ def stream_token_usage(func):
|
||||
yield r
|
||||
for line in batch:
|
||||
self.token_usage["generated_tokens"] += num_tokens_from_string(line)
|
||||
update_token_usage(self.user_api_key, self.token_usage)
|
||||
update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
|
||||
|
||||
return wrapper
|
||||
|
||||
Reference in New Issue
Block a user