fix: connection aborted in WebBaseLoader

This commit is contained in:
Siddhant Rai
2024-05-03 18:25:01 +05:30
parent 7eaa32d85f
commit aa670efe3a
4 changed files with 123 additions and 80 deletions

View File

@@ -36,6 +36,7 @@ current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
"""
Recursively extract zip files with a limit on recursion depth.
@@ -50,7 +51,7 @@ def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
print(f"Reached maximum recursion depth of {max_depth}")
return
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(extract_to)
os.remove(zip_path) # Remove the zip file after extracting
@@ -96,7 +97,6 @@ def ingest_worker(self, directory, formats, name_job, filename, user):
full_path = os.path.join(directory, user, name_job)
import sys
print(full_path, file=sys.stderr)
# check if API_URL env variable is set
file_data = {"name": name_job, "file": filename, "user": user}
@@ -114,7 +114,9 @@ def ingest_worker(self, directory, formats, name_job, filename, user):
# check if file is .zip and extract it
if filename.endswith(".zip"):
extract_zip_recursive(os.path.join(full_path, filename), full_path, 0, recursion_depth)
extract_zip_recursive(
os.path.join(full_path, filename), full_path, 0, recursion_depth
)
self.update_state(state="PROGRESS", meta={"current": 1})
@@ -176,7 +178,6 @@ def ingest_worker(self, directory, formats, name_job, filename, user):
def remote_worker(self, source_data, name_job, user, loader, directory="temp"):
# sample = False
token_check = True
min_tokens = 150
max_tokens = 1250
@@ -184,12 +185,8 @@ def remote_worker(self, source_data, name_job, user, loader, directory="temp"):
if not os.path.exists(full_path):
os.makedirs(full_path)
self.update_state(state="PROGRESS", meta={"current": 1})
# source_data {"data": [url]} for url type task just urls
# Use RemoteCreator to load data from URL
remote_loader = RemoteCreator.create_loader(loader)
raw_docs = remote_loader.load_data(source_data)
@@ -201,7 +198,6 @@ def remote_worker(self, source_data, name_job, user, loader, directory="temp"):
)
# docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
call_openai_api(docs, full_path, self)
self.update_state(state="PROGRESS", meta={"current": 100})