Mirror of https://github.com/arc53/DocsGPT.git

fix: user_api_key capturing

This commit separates the end user's API key from the provider credential: LLMCreator.create_llm and every LLM wrapper now accept a user_api_key parameter, the retrievers store user_api_key in place of api_key, the provider key is taken from settings.API_KEY, and token usage is logged against the user's key.
@@ -184,7 +184,9 @@ def complete_stream(question, retriever, conversation_id, user_api_key):
         elif "source" in line:
             source_log_docs.append(line["source"])

-    llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=user_api_key)
+    llm = LLMCreator.create_llm(
+        settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
+    )
     conversation_id = save_conversation(
         conversation_id, question, response_full, source_log_docs, llm
     )
@@ -252,7 +254,7 @@ def stream():
         prompt=prompt,
         chunks=chunks,
         gpt_model=gpt_model,
-        api_key=user_api_key,
+        user_api_key=user_api_key,
     )

     return Response(
@@ -317,7 +319,7 @@ def api_answer():
         prompt=prompt,
         chunks=chunks,
         gpt_model=gpt_model,
-        api_key=user_api_key,
+        user_api_key=user_api_key,
     )
     source_log_docs = []
     response_full = ""
@@ -327,7 +329,9 @@ def api_answer():
         elif "answer" in line:
             response_full += line["answer"]

-    llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=user_api_key)
+    llm = LLMCreator.create_llm(
+        settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
+    )

     result = {"answer": response_full, "sources": source_log_docs}
     result["conversation_id"] = save_conversation(
@@ -379,7 +383,7 @@ def api_search():
         prompt="default",
         chunks=chunks,
         gpt_model=gpt_model,
-        api_key=user_api_key,
+        user_api_key=user_api_key,
     )
     docs = retriever.search()
     return docs

@@ -4,13 +4,14 @@ from application.core.settings import settings

 class AnthropicLLM(BaseLLM):

-    def __init__(self, api_key=None, *args, **kwargs):
+    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT

         super().__init__(*args, **kwargs)
         self.api_key = (
             api_key or settings.ANTHROPIC_API_KEY
         )  # If not provided, use a default from settings
+        self.user_api_key = user_api_key
         self.anthropic = Anthropic(api_key=self.api_key)
         self.HUMAN_PROMPT = HUMAN_PROMPT
         self.AI_PROMPT = AI_PROMPT

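Every wrapper's __init__ gains the same user_api_key parameter and stashes it for the usage decorators. A minimal construction sketch for the Anthropic wrapper; the import path and key value are assumptions:

```python
from application.llm.anthropic import AnthropicLLM  # import path assumed

llm = AnthropicLLM(
    api_key=None,                     # None falls back to settings.ANTHROPIC_API_KEY
    user_api_key="docsgpt-user-key",  # hypothetical; used only for usage accounting
)
```
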
@@ -5,9 +5,10 @@ import requests

 class DocsGPTAPILLM(BaseLLM):

-    def __init__(self, api_key=None, *args, **kwargs):
+    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.api_key = api_key
+        self.user_api_key = user_api_key
         self.endpoint = "https://llm.docsgpt.co.uk"

     def _raw_gen(self, baseself, model, messages, stream=False, *args, **kwargs):

@@ -4,7 +4,13 @@ from application.llm.base import BaseLLM

 class HuggingFaceLLM(BaseLLM):

     def __init__(
-        self, api_key=None, llm_name="Arc53/DocsGPT-7B", q=False, *args, **kwargs
+        self,
+        api_key=None,
+        user_api_key=None,
+        llm_name="Arc53/DocsGPT-7B",
+        q=False,
+        *args,
+        **kwargs,
     ):
         global hf

@@ -37,6 +43,7 @@ class HuggingFaceLLM(BaseLLM):

         super().__init__(*args, **kwargs)
         self.api_key = api_key
+        self.user_api_key = user_api_key
         pipe = pipeline(
             "text-generation",
             model=model,

@@ -4,7 +4,14 @@ from application.core.settings import settings

 class LlamaCpp(BaseLLM):

-    def __init__(self, api_key=None, llm_name=settings.MODEL_PATH, *args, **kwargs):
+    def __init__(
+        self,
+        api_key=None,
+        user_api_key=None,
+        llm_name=settings.MODEL_PATH,
+        *args,
+        **kwargs,
+    ):
         global llama
         try:
             from llama_cpp import Llama

@@ -15,6 +22,7 @@ class LlamaCpp(BaseLLM):

         super().__init__(*args, **kwargs)
         self.api_key = api_key
+        self.user_api_key = user_api_key
         llama = Llama(model_path=llm_name, n_ctx=2048)

     def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):

@@ -20,8 +20,8 @@ class LLMCreator:
     }

     @classmethod
-    def create_llm(cls, type, api_key, *args, **kwargs):
+    def create_llm(cls, type, api_key, user_api_key, *args, **kwargs):
         llm_class = cls.llms.get(type.lower())
         if not llm_class:
             raise ValueError(f"No LLM class found for type {type}")
-        return llm_class(api_key, *args, **kwargs)
+        return llm_class(api_key, user_api_key, *args, **kwargs)

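Call sites (see the route and retriever hunks) now pass both keys explicitly. A minimal sketch of the new factory contract; the import path and key value are assumptions, while the call shape is taken verbatim from the diff:

```python
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator  # import path assumed

user_api_key = "docsgpt-user-key-123"  # hypothetical per-user key from the request

# The provider credential comes from settings.API_KEY; user_api_key rides
# along only so the wrapper can attribute token usage to the end user.
llm = LLMCreator.create_llm(
    settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
)
```
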
@@ -4,7 +4,7 @@ from application.core.settings import settings

 class OpenAILLM(BaseLLM):

-    def __init__(self, api_key=None, *args, **kwargs):
+    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         global openai
         from openai import OpenAI

@@ -13,6 +13,7 @@ class OpenAILLM(BaseLLM):
             api_key=api_key,
         )
         self.api_key = api_key
+        self.user_api_key = user_api_key

     def _get_openai(self):
         # Import openai when needed

@@ -4,12 +4,13 @@ from application.core.settings import settings

 class PremAILLM(BaseLLM):

-    def __init__(self, api_key=None, *args, **kwargs):
+    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         from premai import Prem

         super().__init__(*args, **kwargs)
         self.client = Prem(api_key=api_key)
         self.api_key = api_key
+        self.user_api_key = user_api_key
         self.project_id = settings.PREMAI_PROJECT_ID

     def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):

@@ -60,7 +60,7 @@ class LineIterator:

 class SagemakerAPILLM(BaseLLM):

-    def __init__(self, api_key=None, *args, **kwargs):
+    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         import boto3

         runtime = boto3.client(

@@ -72,6 +72,7 @@ class SagemakerAPILLM(BaseLLM):

         super().__init__(*args, **kwargs)
         self.api_key = api_key
+        self.user_api_key = user_api_key
         self.endpoint = settings.SAGEMAKER_ENDPOINT
         self.runtime = runtime

@@ -16,7 +16,7 @@ class BraveRetSearch(BaseRetriever):
         prompt,
         chunks=2,
         gpt_model="docsgpt",
-        api_key=None,
+        user_api_key=None,
     ):
         self.question = question
         self.source = source

@@ -24,7 +24,7 @@ class BraveRetSearch(BaseRetriever):
         self.prompt = prompt
         self.chunks = chunks
         self.gpt_model = gpt_model
-        self.api_key = api_key
+        self.user_api_key = user_api_key

     def _get_data(self):
         if self.chunks == 0:

@@ -83,7 +83,9 @@ class BraveRetSearch(BaseRetriever):
         )
         messages_combine.append({"role": "user", "content": self.question})

-        llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key)
+        llm = LLMCreator.create_llm(
+            settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
+        )

         completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
         for line in completion:

@@ -17,7 +17,7 @@ class ClassicRAG(BaseRetriever):
         prompt,
         chunks=2,
         gpt_model="docsgpt",
-        api_key=None,
+        user_api_key=None,
     ):
         self.question = question
         self.vectorstore = self._get_vectorstore(source=source)

@@ -25,7 +25,7 @@ class ClassicRAG(BaseRetriever):
         self.prompt = prompt
         self.chunks = chunks
         self.gpt_model = gpt_model
-        self.api_key = api_key
+        self.user_api_key = user_api_key

     def _get_vectorstore(self, source):
         if "active_docs" in source:

@@ -98,7 +98,9 @@ class ClassicRAG(BaseRetriever):
         )
         messages_combine.append({"role": "user", "content": self.question})

-        llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key)
+        llm = LLMCreator.create_llm(
+            settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
+        )

         completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
         for line in completion:

@@ -16,7 +16,7 @@ class DuckDuckSearch(BaseRetriever):
         prompt,
         chunks=2,
         gpt_model="docsgpt",
-        api_key=None,
+        user_api_key=None,
     ):
         self.question = question
         self.source = source

@@ -24,7 +24,7 @@ class DuckDuckSearch(BaseRetriever):
         self.prompt = prompt
         self.chunks = chunks
         self.gpt_model = gpt_model
-        self.api_key = api_key
+        self.user_api_key = user_api_key

     def _parse_lang_string(self, input_string):
         result = []

@@ -100,7 +100,9 @@ class DuckDuckSearch(BaseRetriever):
         )
         messages_combine.append({"role": "user", "content": self.question})

-        llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key)
+        llm = LLMCreator.create_llm(
+            settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
+        )

         completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
         for line in completion:

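All three retrievers (BraveRetSearch, ClassicRAG, DuckDuckSearch) get the identical rename. A hedged construction sketch; the import path, the argument values, and any constructor parameters not visible in the hunks are assumptions:

```python
from application.retriever.duckduck_search import DuckDuckSearch  # path assumed

retriever = DuckDuckSearch(
    question="How do I deploy DocsGPT?",  # hypothetical question
    source={},                            # hypothetical; assigned to self.source above
    prompt="default",
    chunks=2,
    gpt_model="docsgpt",
    user_api_key="docsgpt-user-key-123",  # replaces the old api_key=... keyword
)
docs = retriever.search()  # as in the api_search route above
```
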
@@ -9,11 +9,11 @@ db = mongo["docsgpt"]
 usage_collection = db["token_usage"]


-def update_token_usage(api_key, token_usage):
+def update_token_usage(user_api_key, token_usage):
     if "pytest" in sys.modules:
         return
     usage_data = {
-        "api_key": api_key,
+        "api_key": user_api_key,
         "prompt_tokens": token_usage["prompt_tokens"],
         "generated_tokens": token_usage["generated_tokens"],
         "timestamp": datetime.now(),

@@ -27,7 +27,7 @@ def gen_token_usage(func):
         self.token_usage["prompt_tokens"] += count_tokens(message["content"])
         result = func(self, model, messages, stream, **kwargs)
         self.token_usage["generated_tokens"] += count_tokens(result)
-        update_token_usage(self.api_key, self.token_usage)
+        update_token_usage(self.user_api_key, self.token_usage)
         return result

     return wrapper

@@ -44,6 +44,6 @@ def stream_token_usage(func):
             yield r
         for line in batch:
             self.token_usage["generated_tokens"] += count_tokens(line)
-        update_token_usage(self.api_key, self.token_usage)
+        update_token_usage(self.user_api_key, self.token_usage)

     return wrapper

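The net effect of the usage.py change, sketched below; the Mongo insert is assumed to happen in the elided body of update_token_usage, and the token counts are illustrative:

```python
from datetime import datetime

# The record is still stored under the field name "api_key", but it now
# carries the *user's* key, so usage is attributed per user even though
# every upstream provider request authenticates with settings.API_KEY.
usage_data = {
    "api_key": "docsgpt-user-key-123",  # self.user_api_key (hypothetical value)
    "prompt_tokens": 512,               # illustrative count
    "generated_tokens": 128,            # illustrative count
    "timestamp": datetime.now(),
}
# usage_collection.insert_one(usage_data)  # assumed insert in the elided body
```
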
results.txt (new file, 12 lines)
@@ -0,0 +1,12 @@
+Base URL:http://petstore.swagger.io,https://api.example.com
+Path1: /pets
+description: None
+parameters: []
+methods:
+get=A paged array of pets
+post=Null response
+Path2: /pets/{petId}
+description: None
+parameters: []
+methods:
+get=Expected response to a valid request