Merge branch 'arc53:main' into Fixes-#1260

This commit is contained in:
Niharika Goulikar
2024-11-17 16:18:11 +05:30
committed by GitHub
15 changed files with 194 additions and 30 deletions

View File

@@ -241,6 +241,7 @@ def complete_stream(
yield f"data: {data}\n\n"
except Exception as e:
print("\033[91merr", str(e), file=sys.stderr)
traceback.print_exc()
data = json.dumps(
{
"type": "error",

View File

@@ -358,7 +358,7 @@ class UploadFile(Resource):
for file in files:
filename = secure_filename(file.filename)
file.save(os.path.join(temp_dir, filename))
print(f"Saved file: {filename}")
zip_path = shutil.make_archive(
base_name=os.path.join(save_dir, job_name),
format="zip",
@@ -366,6 +366,26 @@ class UploadFile(Resource):
)
final_filename = os.path.basename(zip_path)
shutil.rmtree(temp_dir)
task = ingest.delay(
settings.UPLOAD_FOLDER,
[
".rst",
".md",
".pdf",
".txt",
".docx",
".csv",
".epub",
".html",
".mdx",
".json",
".xlsx",
".pptx",
],
job_name,
final_filename,
user,
)
else:
file = files[0]
final_filename = secure_filename(file.filename)
@@ -392,9 +412,10 @@ class UploadFile(Resource):
final_filename,
user,
)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
except Exception as err:
print(f"Error: {err}")
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@@ -465,6 +486,11 @@ class TaskStatus(Resource):
task = celery.AsyncResult(task_id)
task_meta = task.info
print(f"Task status: {task.status}")
if not isinstance(
task_meta, (dict, list, str, int, float, bool, type(None))
):
task_meta = str(task_meta) # Convert to a string representation
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)

View File

@@ -0,0 +1,48 @@
from application.llm.base import BaseLLM
class GoogleLLM(BaseLLM):
    """LLM backend that talks to Google's ``google.generativeai`` SDK.

    The SDK is imported lazily inside the generation path, so this module
    can be imported even when the package is not installed.
    """

    # Roles in incoming messages that must be sent to the SDK as "model".
    # "system" is kept for backward compatibility with the original mapping;
    # "assistant" is the conventional chat-completions name for model turns,
    # which the Gemini API does not accept as-is.
    _ROLE_ALIASES = {"system": "model", "assistant": "model"}

    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
        """Store the credentials; remaining arguments go to ``BaseLLM``.

        Args:
            api_key: Key passed to ``genai.configure`` on every request.
            user_api_key: Stored on the instance; not read by this class.
        """
        super().__init__(*args, **kwargs)
        self.api_key = api_key
        self.user_api_key = user_api_key

    def _clean_messages_google(self, messages):
        """Convert chat messages into the SDK's ``role``/``parts`` dicts.

        The first message is skipped because the callers below pass its
        content separately as ``system_instruction``.
        NOTE(review): this assumes ``messages[0]`` is always the system
        prompt -- confirm against the code that builds ``messages``.

        Args:
            messages: Sequence of dicts with ``"role"`` and ``"content"`` keys.

        Returns:
            A list of ``{"role": ..., "parts": [content]}`` dicts for
            ``messages[1:]``, with "system"/"assistant" roles renamed "model".
        """
        return [
            {
                "role": self._ROLE_ALIASES.get(message["role"], message["role"]),
                "parts": [message["content"]],
            }
            for message in messages[1:]
        ]

    def _build_model(self, model, messages):
        """Configure the SDK and return a ``GenerativeModel`` for this request."""
        import google.generativeai as genai

        genai.configure(api_key=self.api_key)
        return genai.GenerativeModel(
            model, system_instruction=messages[0]["content"]
        )

    def _raw_gen(
        self,
        baseself,
        model,
        messages,
        stream=False,
        **kwargs
    ):
        """Generate a complete response and return its text.

        Args:
            baseself: Accepted for signature compatibility; unused here.
            model: Model name handed to ``genai.GenerativeModel``.
            messages: Chat messages; the first one supplies the system
                instruction, the rest form the conversation.
            stream: Accepted for signature compatibility; unused here.
            **kwargs: Accepted and ignored.

        Returns:
            The ``text`` attribute of the SDK response.
        """
        client = self._build_model(model, messages)
        response = client.generate_content(self._clean_messages_google(messages))
        return response.text

    def _raw_gen_stream(
        self,
        baseself,
        model,
        messages,
        stream=True,
        **kwargs
    ):
        """Generate a response incrementally, yielding text as it arrives.

        Takes the same arguments as ``_raw_gen``; the request is always made
        with ``stream=True`` regardless of the ``stream`` argument.

        Yields:
            The ``text`` of each streamed chunk, skipping chunks whose text
            is ``None``.
        """
        client = self._build_model(model, messages)
        response = client.generate_content(
            self._clean_messages_google(messages), stream=True
        )
        for chunk in response:
            if chunk.text is not None:
                yield chunk.text

View File

@@ -6,6 +6,7 @@ from application.llm.llama_cpp import LlamaCpp
from application.llm.anthropic import AnthropicLLM
from application.llm.docsgpt_provider import DocsGPTAPILLM
from application.llm.premai import PremAILLM
from application.llm.google_ai import GoogleLLM
class LLMCreator:
@@ -18,7 +19,8 @@ class LLMCreator:
"anthropic": AnthropicLLM,
"docsgpt": DocsGPTAPILLM,
"premai": PremAILLM,
"groq": GroqLLM
"groq": GroqLLM,
"google": GoogleLLM
}
@classmethod

View File

@@ -1,10 +1,19 @@
from application.parser.remote.base import BaseRemote
from langchain_community.document_loaders import RedditPostsLoader
import json
class RedditPostsLoaderRemote(BaseRemote):
def load_data(self, inputs):
data = eval(inputs)
try:
data = json.loads(inputs)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON input: {e}")
required_fields = ["client_id", "client_secret", "user_agent", "search_queries"]
missing_fields = [field for field in required_fields if field not in data]
if missing_fields:
raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")
client_id = data.get("client_id")
client_secret = data.get("client_secret")
user_agent = data.get("user_agent")

View File

@@ -45,7 +45,6 @@ class ClassicRAG(BaseRetriever):
settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY
)
docs_temp = docsearch.search(self.question, k=self.chunks)
print(docs_temp)
docs = [
{
"title": i.metadata.get(
@@ -60,8 +59,6 @@ class ClassicRAG(BaseRetriever):
}
for i in docs_temp
]
if settings.LLM_NAME == "llama.cpp":
docs = [docs[0]]
return docs