mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-01-20 05:50:58 +00:00
refactor and deps (#2184)
This commit is contained in:
@@ -148,9 +148,12 @@ class ConversationService:
|
||||
]
|
||||
|
||||
completion = llm.gen(
|
||||
model=model_id, messages=messages_summary, max_tokens=30
|
||||
model=model_id, messages=messages_summary, max_tokens=500
|
||||
)
|
||||
|
||||
if not completion or not completion.strip():
|
||||
completion = question[:50] if question else "New Conversation"
|
||||
|
||||
conversation_data = {
|
||||
"user": user_id,
|
||||
"date": current_time,
|
||||
|
||||
@@ -1,75 +1,19 @@
|
||||
import json
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.base import BaseLLM
|
||||
from application.llm.openai import OpenAILLM
|
||||
|
||||
DOCSGPT_API_KEY = "sk-docsgpt-public"
|
||||
DOCSGPT_BASE_URL = "https://oai.arc53.com"
|
||||
DOCSGPT_MODEL = "docsgpt"
|
||||
|
||||
class DocsGPTAPILLM(BaseLLM):
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.api_key = "sk-docsgpt-public"
|
||||
self.client = OpenAI(api_key=self.api_key, base_url="https://oai.arc53.com")
|
||||
self.user_api_key = user_api_key
|
||||
|
||||
def _clean_messages_openai(self, messages):
|
||||
cleaned_messages = []
|
||||
for message in messages:
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
|
||||
if role == "model":
|
||||
role = "assistant"
|
||||
if role and content is not None:
|
||||
if isinstance(content, str):
|
||||
cleaned_messages.append({"role": role, "content": content})
|
||||
elif isinstance(content, list):
|
||||
for item in content:
|
||||
if "text" in item:
|
||||
cleaned_messages.append(
|
||||
{"role": role, "content": item["text"]}
|
||||
)
|
||||
elif "function_call" in item:
|
||||
cleaned_args = self._remove_null_values(
|
||||
item["function_call"]["args"]
|
||||
)
|
||||
tool_call = {
|
||||
"id": item["function_call"]["call_id"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": item["function_call"]["name"],
|
||||
"arguments": json.dumps(cleaned_args),
|
||||
},
|
||||
}
|
||||
cleaned_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [tool_call],
|
||||
}
|
||||
)
|
||||
elif "function_response" in item:
|
||||
cleaned_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": item["function_response"][
|
||||
"call_id"
|
||||
],
|
||||
"content": json.dumps(
|
||||
item["function_response"]["response"]["result"]
|
||||
),
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected content dictionary format: {item}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected content type: {type(content)}")
|
||||
return cleaned_messages
|
||||
class DocsGPTAPILLM(OpenAILLM):
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
super().__init__(
|
||||
api_key=DOCSGPT_API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
base_url=DOCSGPT_BASE_URL,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _raw_gen(
|
||||
self,
|
||||
@@ -79,23 +23,19 @@ class DocsGPTAPILLM(BaseLLM):
|
||||
stream=False,
|
||||
tools=None,
|
||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
response_format=None,
|
||||
**kwargs,
|
||||
):
|
||||
messages = self._clean_messages_openai(messages)
|
||||
if tools:
|
||||
response = self.client.chat.completions.create(
|
||||
model="docsgpt",
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
**kwargs,
|
||||
)
|
||||
return response.choices[0]
|
||||
else:
|
||||
response = self.client.chat.completions.create(
|
||||
model="docsgpt", messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
return super()._raw_gen(
|
||||
baseself,
|
||||
DOCSGPT_MODEL,
|
||||
messages,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
engine=engine,
|
||||
response_format=response_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _raw_gen_stream(
|
||||
self,
|
||||
@@ -105,34 +45,16 @@ class DocsGPTAPILLM(BaseLLM):
|
||||
stream=True,
|
||||
tools=None,
|
||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
response_format=None,
|
||||
**kwargs,
|
||||
):
|
||||
messages = self._clean_messages_openai(messages)
|
||||
if tools:
|
||||
response = self.client.chat.completions.create(
|
||||
model="docsgpt",
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
response = self.client.chat.completions.create(
|
||||
model="docsgpt", messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
try:
|
||||
for line in response:
|
||||
if (
|
||||
len(line.choices) > 0
|
||||
and line.choices[0].delta.content is not None
|
||||
and len(line.choices[0].delta.content) > 0
|
||||
):
|
||||
yield line.choices[0].delta.content
|
||||
elif len(line.choices) > 0:
|
||||
yield line.choices[0]
|
||||
finally:
|
||||
if hasattr(response, "close"):
|
||||
response.close()
|
||||
|
||||
def _supports_tools(self):
|
||||
return True
|
||||
return super()._raw_gen_stream(
|
||||
baseself,
|
||||
DOCSGPT_MODEL,
|
||||
messages,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
engine=engine,
|
||||
response_format=response_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -1,37 +1,15 @@
|
||||
from openai import OpenAI
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.base import BaseLLM
|
||||
from application.llm.openai import OpenAILLM
|
||||
|
||||
GROQ_BASE_URL = "https://api.groq.com/openai/v1"
|
||||
|
||||
|
||||
class GroqLLM(BaseLLM):
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.api_key = api_key or settings.GROQ_API_KEY or settings.API_KEY
|
||||
self.user_api_key = user_api_key
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key, base_url="https://api.groq.com/openai/v1"
|
||||
class GroqLLM(OpenAILLM):
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
super().__init__(
|
||||
api_key=api_key or settings.GROQ_API_KEY or settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
base_url=base_url or GROQ_BASE_URL,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
|
||||
if tools:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, tools=tools, **kwargs
|
||||
)
|
||||
return response.choices[0]
|
||||
else:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
def _raw_gen_stream(
|
||||
self, baseself, model, messages, stream=True, tools=None, **kwargs
|
||||
):
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
for line in response:
|
||||
if line.choices[0].delta.content is not None:
|
||||
yield line.choices[0].delta.content
|
||||
|
||||
@@ -833,7 +833,10 @@ class LLMHandler(ABC):
|
||||
if call.name:
|
||||
existing.name = call.name
|
||||
if call.arguments:
|
||||
existing.arguments += call.arguments
|
||||
if existing.arguments is None:
|
||||
existing.arguments = call.arguments
|
||||
else:
|
||||
existing.arguments += call.arguments
|
||||
# Preserve thought_signature for Google Gemini 3 models
|
||||
if call.thought_signature:
|
||||
existing.thought_signature = call.thought_signature
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
from application.llm.base import BaseLLM
|
||||
|
||||
|
||||
class HuggingFaceLLM(BaseLLM):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key=None,
|
||||
user_api_key=None,
|
||||
llm_name="Arc53/DocsGPT-7B",
|
||||
q=False,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
global hf
|
||||
|
||||
from langchain.llms import HuggingFacePipeline
|
||||
|
||||
if q:
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
pipeline,
|
||||
BitsAndBytesConfig,
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(llm_name)
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
llm_name, quantization_config=bnb_config
|
||||
)
|
||||
else:
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(llm_name)
|
||||
model = AutoModelForCausalLM.from_pretrained(llm_name)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.api_key = api_key
|
||||
self.user_api_key = user_api_key
|
||||
pipe = pipeline(
|
||||
"text-generation",
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
max_new_tokens=2000,
|
||||
device_map="auto",
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
hf = HuggingFacePipeline(pipeline=pipe)
|
||||
|
||||
def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
|
||||
context = messages[0]["content"]
|
||||
user_question = messages[-1]["content"]
|
||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||
|
||||
result = hf(prompt)
|
||||
|
||||
return result.content
|
||||
|
||||
def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
|
||||
|
||||
raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.")
|
||||
@@ -4,7 +4,6 @@ from application.llm.anthropic import AnthropicLLM
|
||||
from application.llm.docsgpt_provider import DocsGPTAPILLM
|
||||
from application.llm.google_ai import GoogleLLM
|
||||
from application.llm.groq import GroqLLM
|
||||
from application.llm.huggingface import HuggingFaceLLM
|
||||
from application.llm.llama_cpp import LlamaCpp
|
||||
from application.llm.novita import NovitaLLM
|
||||
from application.llm.openai import AzureOpenAILLM, OpenAILLM
|
||||
@@ -19,7 +18,6 @@ class LLMCreator:
|
||||
"openai": OpenAILLM,
|
||||
"azure_openai": AzureOpenAILLM,
|
||||
"sagemaker": SagemakerAPILLM,
|
||||
"huggingface": HuggingFaceLLM,
|
||||
"llama.cpp": LlamaCpp,
|
||||
"anthropic": AnthropicLLM,
|
||||
"docsgpt": DocsGPTAPILLM,
|
||||
|
||||
@@ -1,32 +1,15 @@
|
||||
from application.llm.base import BaseLLM
|
||||
from openai import OpenAI
|
||||
from application.core.settings import settings
|
||||
from application.llm.openai import OpenAILLM
|
||||
|
||||
NOVITA_BASE_URL = "https://api.novita.ai/v3/openai"
|
||||
|
||||
|
||||
class NovitaLLM(BaseLLM):
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.client = OpenAI(api_key=api_key, base_url="https://api.novita.ai/v3/openai")
|
||||
self.api_key = api_key
|
||||
self.user_api_key = user_api_key
|
||||
|
||||
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
|
||||
if tools:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, tools=tools, **kwargs
|
||||
)
|
||||
return response.choices[0]
|
||||
else:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
def _raw_gen_stream(
|
||||
self, baseself, model, messages, stream=True, tools=None, **kwargs
|
||||
):
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
class NovitaLLM(OpenAILLM):
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
super().__init__(
|
||||
api_key=api_key or settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
base_url=base_url or NOVITA_BASE_URL,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
for line in response:
|
||||
if line.choices[0].delta.content is not None:
|
||||
yield line.choices[0].delta.content
|
||||
|
||||
@@ -127,6 +127,7 @@ class OpenAILLM(BaseLLM):
|
||||
**kwargs,
|
||||
):
|
||||
messages = self._clean_messages_openai(messages)
|
||||
logging.info(f"Cleaned messages: {messages}")
|
||||
|
||||
# Convert max_tokens to max_completion_tokens for newer models
|
||||
if "max_tokens" in kwargs:
|
||||
@@ -144,7 +145,7 @@ class OpenAILLM(BaseLLM):
|
||||
if response_format:
|
||||
request_params["response_format"] = response_format
|
||||
response = self.client.chat.completions.create(**request_params)
|
||||
|
||||
logging.info(f"OpenAI response: {response}")
|
||||
if tools:
|
||||
return response.choices[0]
|
||||
else:
|
||||
@@ -162,6 +163,7 @@ class OpenAILLM(BaseLLM):
|
||||
**kwargs,
|
||||
):
|
||||
messages = self._clean_messages_openai(messages)
|
||||
logging.info(f"Cleaned messages: {messages}")
|
||||
|
||||
# Convert max_tokens to max_completion_tokens for newer models
|
||||
if "max_tokens" in kwargs:
|
||||
@@ -182,6 +184,7 @@ class OpenAILLM(BaseLLM):
|
||||
|
||||
try:
|
||||
for line in response:
|
||||
logging.debug(f"OpenAI stream line: {line}")
|
||||
if (
|
||||
len(line.choices) > 0
|
||||
and line.choices[0].delta.content is not None
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, List
|
||||
|
||||
from langchain.docstore.document import Document as LCDocument
|
||||
from langchain_core.documents import Document as LCDocument
|
||||
from application.parser.schema.base import Document
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, List
|
||||
|
||||
from langchain.docstore.document import Document as LCDocument
|
||||
from langchain_core.documents import Document as LCDocument
|
||||
from application.parser.schema.base import Document
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Base schema for readers."""
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langchain.docstore.document import Document as LCDocument
|
||||
from langchain_core.documents import Document as LCDocument
|
||||
from application.parser.schema.schema import BaseDocument
|
||||
|
||||
|
||||
|
||||
@@ -1,91 +1,91 @@
|
||||
anthropic==0.49.0
|
||||
boto3==1.38.18
|
||||
beautifulsoup4==4.13.4
|
||||
celery==5.4.0
|
||||
cryptography==42.0.8
|
||||
anthropic==0.75.0
|
||||
boto3==1.42.6
|
||||
beautifulsoup4==4.14.3
|
||||
celery==5.6.0
|
||||
cryptography==46.0.3
|
||||
dataclasses-json==0.6.7
|
||||
docx2txt==0.8
|
||||
duckduckgo-search==7.5.2
|
||||
duckduckgo-search==8.1.1
|
||||
ebooklib==0.18
|
||||
escodegen==1.0.11
|
||||
esprima==4.0.1
|
||||
esutils==1.0.1
|
||||
elevenlabs==2.17.0
|
||||
Flask==3.1.1
|
||||
faiss-cpu==1.9.0.post1
|
||||
fastmcp==2.11.0
|
||||
flask-restx==1.3.0
|
||||
google-genai==1.49.0
|
||||
google-api-python-client==2.179.0
|
||||
elevenlabs==2.26.1
|
||||
Flask==3.1.2
|
||||
faiss-cpu==1.13.1
|
||||
fastmcp==2.13.3
|
||||
flask-restx==1.3.2
|
||||
google-genai==1.54.0
|
||||
google-api-python-client==2.187.0
|
||||
google-auth-httplib2==0.2.0
|
||||
google-auth-oauthlib==1.2.2
|
||||
gTTS==2.5.4
|
||||
gunicorn==23.0.0
|
||||
javalang==0.13.0
|
||||
jinja2==3.1.6
|
||||
jiter==0.8.2
|
||||
jiter==0.12.0
|
||||
jmespath==1.0.1
|
||||
joblib==1.4.2
|
||||
jsonpatch==1.33
|
||||
jsonpointer==3.0.0
|
||||
kombu==5.4.2
|
||||
langchain==0.3.20
|
||||
langchain-community==0.3.19
|
||||
langchain-core==0.3.59
|
||||
langchain-openai==0.3.16
|
||||
langchain-text-splitters==0.3.8
|
||||
langsmith==0.3.42
|
||||
kombu==5.6.1
|
||||
langchain==1.1.3
|
||||
langchain-community==0.4.1
|
||||
langchain-core==1.1.3
|
||||
langchain-openai==1.1.1
|
||||
langchain-text-splitters==1.0.0
|
||||
langsmith==0.4.58
|
||||
lazy-object-proxy==1.10.0
|
||||
lxml==5.3.1
|
||||
lxml==6.0.2
|
||||
markupsafe==3.0.2
|
||||
marshmallow==3.26.1
|
||||
mpmath==1.3.0
|
||||
multidict==6.4.3
|
||||
multidict==6.7.0
|
||||
mypy-extensions==1.0.0
|
||||
networkx==3.4.2
|
||||
networkx==3.6.1
|
||||
numpy==2.2.1
|
||||
openai==1.78.1
|
||||
openai==2.9.0
|
||||
openapi3-parser==1.1.21
|
||||
orjson==3.10.14
|
||||
orjson==3.11.5
|
||||
packaging==24.2
|
||||
pandas==2.2.3
|
||||
openpyxl==3.1.5
|
||||
pathable==0.4.4
|
||||
pillow==11.1.0
|
||||
pillow==12.0.0
|
||||
portalocker>=2.7.0,<3.0.0
|
||||
prance==23.6.21.0
|
||||
prompt-toolkit==3.0.51
|
||||
protobuf==5.29.3
|
||||
psycopg2-binary==2.9.10
|
||||
psycopg2-binary==2.9.11
|
||||
py==1.11.0
|
||||
pydantic
|
||||
pydantic-core
|
||||
pydantic-settings
|
||||
pymongo==4.11.3
|
||||
pypdf==5.5.0
|
||||
pymongo==4.15.5
|
||||
pypdf==6.4.1
|
||||
python-dateutil==2.9.0.post0
|
||||
python-dotenv
|
||||
python-jose==3.4.0
|
||||
python-pptx==1.0.2
|
||||
redis==5.2.1
|
||||
redis==7.1.0
|
||||
referencing>=0.28.0,<0.31.0
|
||||
regex==2024.11.6
|
||||
requests==2.32.3
|
||||
regex==2025.11.3
|
||||
requests==2.32.5
|
||||
retry==0.9.2
|
||||
sentence-transformers==3.3.1
|
||||
tiktoken==0.8.0
|
||||
tokenizers==0.21.0
|
||||
sentence-transformers==5.1.2
|
||||
tiktoken==0.12.0
|
||||
tokenizers==0.22.1
|
||||
torch==2.7.0
|
||||
tqdm==4.67.1
|
||||
transformers==4.51.3
|
||||
typing-extensions==4.12.2
|
||||
transformers==4.57.3
|
||||
typing-extensions==4.15.0
|
||||
typing-inspect==0.9.0
|
||||
tzdata==2024.2
|
||||
urllib3==2.3.0
|
||||
tzdata==2025.2
|
||||
urllib3==2.6.1
|
||||
vine==5.1.0
|
||||
wcwidth==0.2.13
|
||||
werkzeug>=3.1.0,<3.1.2
|
||||
yarl==1.20.0
|
||||
markdownify==1.1.0
|
||||
tldextract==5.1.3
|
||||
websockets==14.1
|
||||
werkzeug>=3.1.0
|
||||
yarl==1.22.0
|
||||
markdownify==1.2.2
|
||||
tldextract==5.3.0
|
||||
websockets==15.0.1
|
||||
@@ -2,7 +2,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
from application.parser.remote.crawler_loader import CrawlerLoader
|
||||
from application.parser.schema.base import Document
|
||||
from langchain.docstore.document import Document as LCDocument
|
||||
from langchain_core.documents import Document as LCDocument
|
||||
|
||||
|
||||
class DummyResponse:
|
||||
|
||||
@@ -4,7 +4,7 @@ from urllib.parse import urlparse
|
||||
|
||||
from application.parser.remote.web_loader import WebLoader, headers
|
||||
from application.parser.schema.base import Document
|
||||
from langchain.docstore.document import Document as LCDocument
|
||||
from langchain_core.documents import Document as LCDocument
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
Reference in New Issue
Block a user