mirror of
https://github.com/datalab-to/chandra.git
synced 2025-11-29 00:23:12 +00:00
286 lines
8.3 KiB
Python
Executable File
286 lines
8.3 KiB
Python
Executable File
import json
|
|
from pathlib import Path
|
|
from typing import List
|
|
|
|
import click
|
|
|
|
from chandra.input import load_file
|
|
from chandra.model import InferenceManager
|
|
from chandra.model.schema import BatchInputItem
|
|
|
|
|
|
def get_supported_files(input_path: Path) -> List[Path]:
    """Get list of supported image/PDF files from path.

    If ``input_path`` is a file, it is returned as a one-element list when its
    extension (compared case-insensitively) is supported. If it is a
    directory, its immediate children (non-recursive) that are regular files
    with a supported extension are returned, sorted by path.

    Args:
        input_path: A single file or a directory to scan.

    Returns:
        Sorted list of supported file paths; may be empty for a directory.

    Raises:
        click.BadParameter: If a single file has an unsupported extension, or
            the path is neither an existing file nor a directory.
    """
    supported_extensions = {
        ".pdf",
        ".png",
        ".jpg",
        ".jpeg",
        ".gif",
        ".webp",
        ".tiff",
        ".bmp",
    }

    if input_path.is_file():
        if input_path.suffix.lower() in supported_extensions:
            return [input_path]
        raise click.BadParameter(f"Unsupported file type: {input_path.suffix}")

    if input_path.is_dir():
        # Filter one directory listing by lowercased suffix instead of globbing
        # once per case variant. This catches mixed-case suffixes (".Pdf"),
        # never yields the same file twice where glob matching is
        # case-insensitive (Windows), and skips sub-directories whose names
        # happen to end in a supported extension.
        return sorted(
            entry
            for entry in input_path.iterdir()
            if entry.suffix.lower() in supported_extensions and entry.is_file()
        )

    raise click.BadParameter(f"Path does not exist: {input_path}")
|
|
|
|
|
|
def save_merged_output(
    output_dir: Path,
    file_name: str,
    results: List,
    save_images: bool = True,
    save_html: bool = True,
    paginate_output: bool = False,
):
    """Save merged OCR results for all pages to output directory.

    Writes into ``output_dir / <stem of file_name>``:

    * ``<stem>.md``: the markdown of every page, in order;
    * ``<stem>.html``: the HTML of every page (only when ``save_html``);
    * ``<stem>_metadata.json``: per-page and total token/chunk/image counts;
    * ``images/<name>``: every entry of each page's ``images`` mapping
      (only when ``save_images``).

    Args:
        output_dir: Root output directory; a per-file subfolder is created.
        file_name: Name of the source file. Its stem names the subfolder and
            the output files; the full name is recorded in the metadata.
        results: One result per page, in page order. Each result is read for
            ``markdown``, ``html``, ``token_count``, ``chunks``, ``images``
            (a name -> image mapping whose values provide ``.save(path)``)
            and ``page_box``.
        save_images: Whether to write each page's images to disk.
        save_html: Whether to write the merged HTML file.
        paginate_output: When True, every page after the first is preceded by
            a separator. In markdown this is the 0-based page index followed
            by 48 dashes. In HTML it is a ``<!-- Page N -->`` comment with a
            1-based N. When False, markdown pages are separated by a blank
            line only.
    """
    # Create subfolder for this file.
    # NOTE(review): only the stem is used, so inputs sharing a stem (e.g.
    # "report.pdf" and "report.png") write into the same folder and overwrite
    # each other's outputs.
    safe_name = Path(file_name).stem
    file_output_dir = output_dir / safe_name
    file_output_dir.mkdir(parents=True, exist_ok=True)

    all_markdown = []
    all_html = []
    all_metadata = []
    total_tokens = 0
    total_chunks = 0
    total_images = 0

    for page_num, result in enumerate(results):
        if page_num > 0:
            if paginate_output:
                all_markdown.append(f"\n\n{page_num}" + "-" * 48 + "\n\n")
                all_html.append(f"\n\n<!-- Page {page_num + 1} -->\n\n")
            else:
                # Keep consecutive pages apart with a blank line so the first
                # block of this page (e.g. a heading) is not glued onto the
                # last line of the previous page.
                all_markdown.append("\n\n")

        all_markdown.append(result.markdown)
        all_html.append(result.html)

        # Accumulate metadata.
        num_chunks = len(result.chunks)
        num_images = len(result.images)
        total_tokens += result.token_count
        total_chunks += num_chunks
        total_images += num_images

        all_metadata.append(
            {
                "page_num": page_num,
                "page_box": result.page_box,
                "token_count": result.token_count,
                "num_chunks": num_chunks,
                "num_images": num_images,
            }
        )

        # Save extracted images if requested.
        if save_images and result.images:
            images_dir = file_output_dir / "images"
            images_dir.mkdir(exist_ok=True)
            for img_name, pil_image in result.images.items():
                pil_image.save(images_dir / img_name)

    # Save merged markdown.
    markdown_path = file_output_dir / f"{safe_name}.md"
    markdown_path.write_text("".join(all_markdown), encoding="utf-8")

    # Save merged HTML if requested.
    if save_html:
        html_path = file_output_dir / f"{safe_name}.html"
        html_path.write_text("".join(all_html), encoding="utf-8")

    # Save combined metadata.
    metadata = {
        "file_name": file_name,
        "num_pages": len(results),
        "total_token_count": total_tokens,
        "total_chunks": total_chunks,
        "total_images": total_images,
        "pages": all_metadata,
    }
    metadata_path = file_output_dir / f"{safe_name}_metadata.json"
    with open(metadata_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    click.echo(f" Saved: {markdown_path} ({len(results)} page(s))")
|
|
|
|
|
|
@click.command()
@click.argument("input_path", type=click.Path(exists=True, path_type=Path))
@click.argument("output_path", type=click.Path(path_type=Path))
@click.option(
    "--method",
    type=click.Choice(["hf", "vllm"], case_sensitive=False),
    default="vllm",
    help="Inference method: 'hf' for local model, 'vllm' for vLLM server.",
)
@click.option(
    "--page-range",
    type=str,
    default=None,
    help="Page range for PDFs (e.g., '1-5,7,9-12'). Only applicable to PDF files.",
)
@click.option(
    "--max-output-tokens",
    type=int,
    default=None,
    help="Maximum number of output tokens per page.",
)
@click.option(
    "--max-workers",
    type=int,
    default=None,
    help="Maximum number of parallel workers for vLLM inference.",
)
@click.option(
    "--max-retries",
    type=int,
    default=None,
    help="Maximum number of retries for vLLM inference.",
)
@click.option(
    "--include-images/--no-images",
    default=True,
    help="Include images in output.",
)
@click.option(
    "--include-headers-footers/--no-headers-footers",
    default=False,
    help="Include page headers and footers in output.",
)
@click.option(
    "--save-html/--no-html",
    default=True,
    help="Save HTML output files.",
)
@click.option(
    "--batch-size",
    # A step of 0 (or a negative step) would make the batching range() invalid.
    type=click.IntRange(min=1),
    default=1,
    help="Number of pages to process in a batch.",
)
@click.option(
    # Kebab-case spelling for consistency with the other options; the original
    # underscore spelling is kept so existing invocations keep working.
    "--paginate-output",
    "--paginate_output",
    "paginate_output",
    is_flag=True,
    default=False,
    help="Insert a separator between pages in the markdown and HTML output.",
)
def main(
    input_path: Path,
    output_path: Path,
    method: str,
    page_range: str,
    max_output_tokens: int,
    max_workers: int,
    max_retries: int,
    include_images: bool,
    include_headers_footers: bool,
    save_html: bool,
    batch_size: int,
    paginate_output: bool,
):
    """Run Chandra OCR on INPUT_PATH and write results under OUTPUT_PATH.

    INPUT_PATH may be a single PDF/image file or a directory whose top-level
    supported files are processed in sorted order. For each file, all pages
    are merged into one markdown file (plus optional HTML) and a metadata
    JSON inside OUTPUT_PATH/<file stem>/.

    \f
    Files that fail are reported and skipped; the command exits with status
    1 if any file failed.
    """
    click.echo("Chandra CLI - Starting OCR processing")
    click.echo(f"Input: {input_path}")
    click.echo(f"Output: {output_path}")
    click.echo(f"Method: {method}")

    # Create output directory.
    output_path.mkdir(parents=True, exist_ok=True)

    # Discover inputs before loading the model, so an unsupported file or an
    # empty directory is reported without paying the model-load cost.
    files_to_process = get_supported_files(input_path)
    total_files = len(files_to_process)
    click.echo(f"\nFound {total_files} file(s) to process.")
    if not files_to_process:
        click.echo("No supported files found. Exiting.")
        return

    # Load model.
    click.echo(f"\nLoading model with method '{method}'...")
    model = InferenceManager(method=method)
    click.echo("Model loaded successfully.")

    # Generation settings are the same for every file and batch, so build
    # them once. They are unpacked with ** at each call, so the callee gets
    # its own copy.
    generate_kwargs = {
        "include_images": include_images,
        "include_headers_footers": include_headers_footers,
    }
    if max_output_tokens is not None:
        generate_kwargs["max_output_tokens"] = max_output_tokens
    if method == "vllm":
        # Worker/retry limits are only forwarded for the vLLM backend.
        if max_workers is not None:
            generate_kwargs["max_workers"] = max_workers
        if max_retries is not None:
            generate_kwargs["max_retries"] = max_retries

    failed_files = []
    for file_idx, file_path in enumerate(files_to_process, 1):
        click.echo(f"\n[{file_idx}/{total_files}] Processing: {file_path.name}")

        try:
            # Load images from file. Build a fresh config per file in case
            # load_file mutates it.
            config = {"page_range": page_range} if page_range else {}
            images = load_file(str(file_path), config)
            click.echo(f" Loaded {len(images)} page(s)")

            # Accumulate all page results for this document, batch by batch.
            all_results = []
            for batch_start in range(0, len(images), batch_size):
                batch_end = min(batch_start + batch_size, len(images))
                batch = [
                    BatchInputItem(image=img, prompt_type="ocr_layout")
                    for img in images[batch_start:batch_end]
                ]

                click.echo(f" Processing pages {batch_start + 1}-{batch_end}...")
                all_results.extend(model.generate(batch, **generate_kwargs))

            # Save merged output for all pages.
            save_merged_output(
                output_path,
                file_path.name,
                all_results,
                save_images=include_images,
                save_html=save_html,
                paginate_output=paginate_output,
            )

            click.echo(f" Completed: {file_path.name}")

        except Exception as e:
            # Best-effort: report this file's failure and continue with the rest.
            click.echo(f" Error processing {file_path.name}: {e}", err=True)
            failed_files.append(file_path.name)

    click.echo(f"\nProcessing complete. Results saved to: {output_path}")

    if failed_files:
        # Surface partial failure to scripts via a non-zero exit status.
        click.echo(
            f"{len(failed_files)} of {total_files} file(s) failed: "
            + ", ".join(failed_files),
            err=True,
        )
        click.get_current_context().exit(1)


if __name__ == "__main__":
    main()
|