mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-02-22 20:32:11 +00:00
fix: token calc (#2285)
This commit is contained in:
@@ -13,10 +13,12 @@ class BaseLLM(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
decoded_token=None,
|
||||
agent_id=None,
|
||||
model_id=None,
|
||||
base_url=None,
|
||||
):
|
||||
self.decoded_token = decoded_token
|
||||
self.agent_id = str(agent_id) if agent_id else None
|
||||
self.model_id = model_id
|
||||
self.base_url = base_url
|
||||
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
@@ -33,9 +35,10 @@ class BaseLLM(ABC):
|
||||
self._fallback_llm = LLMCreator.create_llm(
|
||||
settings.FALLBACK_LLM_PROVIDER,
|
||||
api_key=settings.FALLBACK_LLM_API_KEY or settings.API_KEY,
|
||||
user_api_key=None,
|
||||
user_api_key=getattr(self, "user_api_key", None),
|
||||
decoded_token=self.decoded_token,
|
||||
model_id=settings.FALLBACK_LLM_NAME,
|
||||
agent_id=self.agent_id,
|
||||
)
|
||||
logger.info(
|
||||
f"Fallback LLM initialized: {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
|
||||
|
||||
@@ -13,7 +13,7 @@ class GoogleLLM(BaseLLM):
|
||||
def __init__(
|
||||
self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
super().__init__(decoded_token=decoded_token, *args, **kwargs)
|
||||
self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
|
||||
self.user_api_key = user_api_key
|
||||
|
||||
|
||||
@@ -567,6 +567,7 @@ class LLMHandler(ABC):
|
||||
getattr(agent, "user_api_key", None),
|
||||
getattr(agent, "decoded_token", None),
|
||||
model_id=compression_model,
|
||||
agent_id=getattr(agent, "agent_id", None),
|
||||
)
|
||||
|
||||
# Create service without DB persistence capability
|
||||
|
||||
@@ -31,7 +31,15 @@ class LLMCreator:
|
||||
|
||||
@classmethod
|
||||
def create_llm(
|
||||
cls, type, api_key, user_api_key, decoded_token, model_id=None, *args, **kwargs
|
||||
cls,
|
||||
type,
|
||||
api_key,
|
||||
user_api_key,
|
||||
decoded_token,
|
||||
model_id=None,
|
||||
agent_id=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
from application.core.model_utils import get_base_url_for_model
|
||||
|
||||
@@ -49,6 +57,7 @@ class LLMCreator:
|
||||
user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
base_url=base_url,
|
||||
*args,
|
||||
**kwargs,
|
||||
|
||||
Reference in New Issue
Block a user