DocsGPT/application/utils.py
Alex f61d112cea feat: process pdfs synthetically if model does not support file natively (#2263)
* feat: process pdfs synthetically if model does not support file natively

* fix: small code optimisations
2026-01-15 02:30:33 +02:00


import base64
import hashlib
import io
import logging
import os
import re
import uuid
from typing import List
import tiktoken
from flask import jsonify, make_response
from werkzeug.utils import secure_filename
from application.core.model_utils import get_token_limit
from application.core.settings import settings
logger = logging.getLogger(__name__)
_encoding = None
def get_encoding():
    """Return a cached tiktoken encoding, creating it on first use."""
    global _encoding
    if _encoding is None:
        _encoding = tiktoken.get_encoding("cl100k_base")
    return _encoding
def get_gpt_model() -> str:
    """Get GPT model based on provider"""
    model_map = {
        "openai": "gpt-4o-mini",
        "anthropic": "claude-2",
        "groq": "llama3-8b-8192",
        "novita": "deepseek/deepseek-r1",
    }
    return settings.LLM_NAME or model_map.get(settings.LLM_PROVIDER, "")
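# Illustrative resolution order (assuming settings.LLM_NAME is unset):
#   settings.LLM_PROVIDER = "groq" -> "llama3-8b-8192"
#   settings.LLM_PROVIDER = "foo"  -> "" (unknown provider falls back to an empty string)
# An explicitly configured settings.LLM_NAME always takes precedence over the provider map.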
def safe_filename(filename):
    """Create safe filename, preserving extension. Handles non-Latin characters."""
    if not filename:
        return str(uuid.uuid4())
    _, extension = os.path.splitext(filename)
    safe_name = secure_filename(filename)
    # If secure_filename returns just the extension or an empty string
    if not safe_name or safe_name == extension.lstrip("."):
        return f"{str(uuid.uuid4())}{extension}"
    return safe_name
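# Usage sketch (illustrative): secure_filename drops non-Latin characters, so a name
# like "отчёт.pdf" collapses to just "pdf" and safe_filename falls back to a
# UUID-based name that keeps the ".pdf" extension. A plain ASCII name such as
# "report 2024.pdf" passes through as "report_2024.pdf".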
def num_tokens_from_string(string: str) -> int:
    """Count tokens in a string; non-string input counts as 0 tokens."""
    encoding = get_encoding()
    if isinstance(string, str):
        num_tokens = len(encoding.encode(string))
        return num_tokens
    else:
        return 0
def num_tokens_from_object_or_list(thing):
    """Recursively count tokens across nested lists, dicts, and strings."""
    if isinstance(thing, list):
        return sum([num_tokens_from_object_or_list(x) for x in thing])
    elif isinstance(thing, dict):
        return sum([num_tokens_from_object_or_list(x) for x in thing.values()])
    elif isinstance(thing, str):
        return num_tokens_from_string(thing)
    else:
        return 0
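# Illustrative call (exact counts depend on the cl100k_base encoding):
#   num_tokens_from_object_or_list({"query": "hello", "docs": ["a", "b"]})
# walks the dict values and list items, summing num_tokens_from_string over every
# string it finds; non-string leaves (numbers, None) contribute 0.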
def count_tokens_docs(docs):
    """Count tokens across the concatenated page_content of a list of documents."""
    docs_content = ""
    for doc in docs:
        docs_content += doc.page_content
    tokens = num_tokens_from_string(docs_content)
    return tokens
def calculate_doc_token_budget(model_id: str = "gpt-4o") -> int:
    """Token budget available for documents: model context minus reserved tokens."""
    total_context = get_token_limit(model_id)
    reserved = sum(settings.RESERVED_TOKENS.values())
    doc_budget = total_context - reserved
    return max(doc_budget, 1000)
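# Worked example (illustrative numbers): for a model with a 128,000-token context
# window and RESERVED_TOKENS = {"prompt": 2000, "history": 4000, "response": 2000},
# the document budget is 128000 - 8000 = 120000 tokens. If the reserve ever exceeded
# the context window, the floor of 1000 tokens would be returned instead.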
def get_missing_fields(data, required_fields):
    """Check for missing required fields. Returns list of missing field names."""
    return [field for field in required_fields if field not in data]
def check_required_fields(data, required_fields):
    """Validate required fields. Returns Flask 400 response if validation fails, None otherwise."""
    missing_fields = get_missing_fields(data, required_fields)
    if missing_fields:
        return make_response(
            jsonify(
                {
                    "success": False,
                    "message": f"Missing required fields: {', '.join(missing_fields)}",
                }
            ),
            400,
        )
    return None
def get_field_validation_errors(data, required_fields):
    """Check for missing and empty fields. Returns dict with 'missing_fields' and 'empty_fields', or None."""
    missing_fields = []
    empty_fields = []
    for field in required_fields:
        if field not in data:
            missing_fields.append(field)
        elif not data[field]:
            empty_fields.append(field)
    if missing_fields or empty_fields:
        return {"missing_fields": missing_fields, "empty_fields": empty_fields}
    return None
def validate_required_fields(data, required_fields):
    """Validate required fields (must exist and be non-empty). Returns Flask 400 response if validation fails, None otherwise."""
    errors_dict = get_field_validation_errors(data, required_fields)
    if errors_dict:
        errors = []
        if errors_dict["missing_fields"]:
            errors.append(
                f"Missing required fields: {', '.join(errors_dict['missing_fields'])}"
            )
        if errors_dict["empty_fields"]:
            errors.append(
                f"Empty values in required fields: {', '.join(errors_dict['empty_fields'])}"
            )
        return make_response(
            jsonify({"success": False, "message": " | ".join(errors)}), 400
        )
    return None
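# Typical call site (sketch; the route and field names below are hypothetical):
#   @app.route("/api/example", methods=["POST"])
#   def example():
#       data = request.get_json()
#       error = validate_required_fields(data, ["user", "prompt"])
#       if error:
#           return error  # ready-made 400 response
#       ...
# Missing keys and present-but-falsy values are reported separately in the message.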
def get_hash(data):
    """MD5 hex digest of a string, used as a non-cryptographic identifier."""
    return hashlib.md5(data.encode(), usedforsecurity=False).hexdigest()
def limit_chat_history(history, max_token_limit=None, model_id="docsgpt-local"):
    """Limit chat history to fit within token limit."""
    model_token_limit = get_token_limit(model_id)
    max_token_limit = (
        max_token_limit
        if max_token_limit and max_token_limit < model_token_limit
        else model_token_limit
    )
    if not history:
        return []
    trimmed_history = []
    tokens_current_history = 0
    for message in reversed(history):
        tokens_batch = 0
        if "prompt" in message and "response" in message:
            tokens_batch += num_tokens_from_string(message["prompt"])
            tokens_batch += num_tokens_from_string(message["response"])
        if "tool_calls" in message:
            for tool_call in message["tool_calls"]:
                tool_call_string = f"Tool: {tool_call.get('tool_name')} | Action: {tool_call.get('action_name')} | Args: {tool_call.get('arguments')} | Response: {tool_call.get('result')}"
                tokens_batch += num_tokens_from_string(tool_call_string)
        if tokens_current_history + tokens_batch < max_token_limit:
            tokens_current_history += tokens_batch
            trimmed_history.insert(0, message)
        else:
            break
    return trimmed_history
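# Behaviour sketch: history is walked newest-to-oldest and messages are kept until
# the running token count would reach the limit, so the most recent exchanges
# survive trimming. Example (illustrative): with a 100-token limit and three
# 40-token prompt/response pairs, only the two newest pairs are returned, in their
# original chronological order.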
def validate_function_name(function_name):
    """Validate function name matches allowed pattern (alphanumeric, underscore, hyphen)."""
    if not re.match(r"^[a-zA-Z0-9_-]+$", function_name):
        return False
    return True
def generate_image_url(image_path):
    """Build a public URL for a stored image, honouring the configured URL strategy."""
    if isinstance(image_path, str) and (
        image_path.startswith("http://") or image_path.startswith("https://")
    ):
        return image_path
    strategy = getattr(settings, "URL_STRATEGY", "backend")
    if strategy == "s3":
        bucket_name = getattr(settings, "S3_BUCKET_NAME", "docsgpt-test-bucket")
        region_name = getattr(settings, "SAGEMAKER_REGION", "eu-central-1")
        return f"https://{bucket_name}.s3.{region_name}.amazonaws.com/{image_path}"
    else:
        base_url = getattr(settings, "API_URL", "http://localhost:7091")
        return f"{base_url}/api/images/{image_path}"
def calculate_compression_threshold(
    model_id: str, threshold_percentage: float = 0.8
) -> int:
    """
    Calculate token threshold for triggering compression.

    Args:
        model_id: Model identifier
        threshold_percentage: Percentage of context window (default 80%)

    Returns:
        Token count threshold
    """
    total_context = get_token_limit(model_id)
    threshold = int(total_context * threshold_percentage)
    return threshold
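# Worked example (illustrative): a model with a 32,000-token context window and the
# default threshold_percentage of 0.8 gives int(32000 * 0.8) = 25600, i.e.
# compression is triggered once the conversation crosses 25,600 tokens.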
def convert_pdf_to_images(
    file_path: str,
    storage=None,
    max_pages: int = 20,
    dpi: int = 150,
    image_format: str = "PNG",
) -> List[dict]:
    """
    Convert PDF pages to images for LLMs that support images but not PDFs.

    This enables "synthetic PDF support" by converting each PDF page to an image
    that can be sent to vision-capable LLMs like Claude.

    Args:
        file_path: Path to the PDF file (can be storage path)
        storage: Optional storage instance for retrieving files
        max_pages: Maximum number of pages to convert (default 20 to avoid context overflow)
        dpi: Resolution for rendering (default 150 for balance of quality/size)
        image_format: Output format (PNG recommended for quality)

    Returns:
        List of dicts with keys:
            - 'data': base64-encoded image data
            - 'mime_type': MIME type (e.g., 'image/png')
            - 'page': Page number (1-indexed)

    Raises:
        ImportError: If pdf2image is not installed
        FileNotFoundError: If file doesn't exist
        Exception: If conversion fails
    """
    try:
        from pdf2image import convert_from_path, convert_from_bytes
    except ImportError:
        raise ImportError(
            "pdf2image is required for PDF-to-image conversion. "
            "Install it with: pip install pdf2image\n"
            "Also ensure poppler-utils is installed on your system."
        )
    images_data = []
    mime_type = f"image/{image_format.lower()}"
    try:
        # Get PDF content either from storage or direct file path
        if storage and hasattr(storage, "get_file"):
            with storage.get_file(file_path) as pdf_file:
                pdf_bytes = pdf_file.read()
            pil_images = convert_from_bytes(
                pdf_bytes,
                dpi=dpi,
                fmt=image_format.lower(),
                first_page=1,
                last_page=max_pages,
            )
        else:
            pil_images = convert_from_path(
                file_path,
                dpi=dpi,
                fmt=image_format.lower(),
                first_page=1,
                last_page=max_pages,
            )
        for page_num, pil_image in enumerate(pil_images, start=1):
            # Convert PIL image to base64
            buffer = io.BytesIO()
            pil_image.save(buffer, format=image_format)
            buffer.seek(0)
            base64_data = base64.b64encode(buffer.read()).decode("utf-8")
            images_data.append({
                "data": base64_data,
                "mime_type": mime_type,
                "page": page_num,
            })
        return images_data
    except FileNotFoundError:
        logger.error(f"PDF file not found: {file_path}")
        raise
    except Exception as e:
        logger.error(f"Error converting PDF to images: {e}", exc_info=True)
        raise
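# Usage sketch (the surrounding message structure is an assumption, not part of this
# module): feed the converted pages to a vision-capable model as image parts.
#   pages = convert_pdf_to_images("uploads/report.pdf", max_pages=5)
#   image_parts = [
#       {"type": "image_url",
#        "image_url": {"url": f"data:{p['mime_type']};base64,{p['data']}"}}
#       for p in pages
#   ]
# Each page arrives as {'data': <base64>, 'mime_type': 'image/png', 'page': n}.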
def clean_text_for_tts(text: str) -> str:
    """
    Clean text for Text-to-Speech processing.
    """
    # Handle code blocks, images and links
    text = re.sub(r"```mermaid[\s\S]*?```", " flowchart, ", text)  ## ```mermaid...```
    text = re.sub(r"```[\s\S]*?```", " code block, ", text)  ## ```code```
    # Strip images before links so the link rule does not consume "![alt](url)" first
    text = re.sub(r"!\[([^\]]*)\]\([^\)]+\)", "", text)  ## ![alt](url)
    text = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", text)  ## [text](url)
    # Remove markdown formatting
    text = re.sub(r"`([^`]+)`", r"\1", text)  ## `code`
    text = re.sub(r"\{([^}]*)\}", r" \1 ", text)  ## {text}
    text = re.sub(r"[{}]", " ", text)  ## unmatched {}
    text = re.sub(r"\[([^\]]+)\]", r" \1 ", text)  ## [text]
    text = re.sub(r"[\[\]]", " ", text)  ## unmatched []
    text = re.sub(r"(\*\*|__)(.*?)\1", r"\2", text)  ## **bold** __bold__
    text = re.sub(r"(\*|_)(.*?)\1", r"\2", text)  ## *italic* _italic_
    text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE)  ## # headers
    text = re.sub(r"^>\s+", "", text, flags=re.MULTILINE)  ## > blockquotes
    text = re.sub(r"^[\s]*[-\*\+]\s+", "", text, flags=re.MULTILINE)  ## - * + lists
    text = re.sub(r"^[\s]*\d+\.\s+", "", text, flags=re.MULTILINE)  ## 1. numbered lists
    text = re.sub(
        r"^[\*\-_]{3,}\s*$", "", text, flags=re.MULTILINE
    )  ## --- *** ___ rules
    text = re.sub(r"<[^>]*>", "", text)  ## <html> tags
    # Remove non-ASCII (emojis, special Unicode)
    text = re.sub(r"[^\x20-\x7E\n\r\t]", "", text)
    # Replace special sequences
    text = re.sub(r"-->", ", ", text)  ## -->
    text = re.sub(r"<--", ", ", text)  ## <--
    text = re.sub(r"=>", ", ", text)  ## =>
    text = re.sub(r"::", " ", text)  ## ::
    # Normalize whitespace
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text
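# Illustrative before/after (following the rules above):
#   "## Setup\nRun `pip install x` -- see [docs](https://example.com) ![logo](logo.png)"
# becomes roughly:
#   "Setup Run pip install x -- see docs"
# Code fences are replaced with the spoken cue " code block, " and mermaid fences
# with " flowchart, " so the TTS output stays listenable.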