Files
chandra/chandra_cli.py
2025-10-20 15:11:12 -04:00

286 lines
8.3 KiB
Python
Executable File

import json
from pathlib import Path
from typing import List
import click
from chandra.input import load_file
from chandra.model import InferenceManager
from chandra.model.schema import BatchInputItem
def get_supported_files(input_path: Path) -> List[Path]:
    """Get the list of supported image/PDF files for *input_path*.

    Args:
        input_path: Either a single file or a directory to scan
            (non-recursively) for supported files.

    Returns:
        A sorted list of paths whose extension matches a supported
        image/PDF format. Extension matching is case-insensitive.

    Raises:
        click.BadParameter: If the path does not exist, or if a single
            file has an unsupported extension.
    """
    supported_extensions = {
        ".pdf",
        ".png",
        ".jpg",
        ".jpeg",
        ".gif",
        ".webp",
        ".tiff",
        ".bmp",
    }
    if input_path.is_file():
        if input_path.suffix.lower() in supported_extensions:
            return [input_path]
        raise click.BadParameter(f"Unsupported file type: {input_path.suffix}")
    if input_path.is_dir():
        # Compare lowercased suffixes in a single directory pass. The previous
        # approach globbed "*{ext}" and "*{EXT}" separately, which missed
        # mixed-case extensions (e.g. ".Pdf") and could return duplicates on
        # case-insensitive filesystems.
        return sorted(
            entry
            for entry in input_path.iterdir()
            if entry.is_file() and entry.suffix.lower() in supported_extensions
        )
    raise click.BadParameter(f"Path does not exist: {input_path}")
def save_merged_output(
    output_dir: Path,
    file_name: str,
    results: List,
    save_images: bool = True,
    save_html: bool = True,
    paginate_output: bool = False,
):
    """Write the merged OCR output for one document.

    Creates a subdirectory of *output_dir* named after the document stem and
    writes a combined markdown file, optionally a combined HTML file, a JSON
    metadata summary, and (optionally) every extracted page image.

    Args:
        output_dir: Root directory that receives the per-document subfolder.
        file_name: Original document name; its stem names the subfolder/files.
        results: Per-page OCR result objects exposing ``markdown``, ``html``,
            ``token_count``, ``chunks``, ``images``, and ``page_box``.
        save_images: When True, save each page's extracted images.
        save_html: When True, also write the merged HTML file.
        paginate_output: When True, insert page-separator markers between pages.
    """
    stem = Path(file_name).stem
    doc_dir = output_dir / stem
    doc_dir.mkdir(parents=True, exist_ok=True)

    markdown_parts: List[str] = []
    html_parts: List[str] = []
    page_records = []
    tokens = 0
    chunk_total = 0
    image_total = 0

    for idx, page in enumerate(results):
        # Separator goes between pages (never before the first one).
        if paginate_output and idx > 0:
            markdown_parts.append(f"\n\n{idx}" + "-" * 48 + "\n\n")
            html_parts.append(f"\n\n<!-- Page {idx + 1} -->\n\n")
        markdown_parts.append(page.markdown)
        html_parts.append(page.html)

        tokens += page.token_count
        chunk_total += len(page.chunks)
        image_total += len(page.images)
        page_records.append(
            {
                "page_num": idx,
                "page_box": page.page_box,
                "token_count": page.token_count,
                "num_chunks": len(page.chunks),
                "num_images": len(page.images),
            }
        )

        # Persist this page's extracted images, if any were produced.
        if save_images and page.images:
            images_dir = doc_dir / "images"
            images_dir.mkdir(exist_ok=True)
            for image_name, pil_image in page.images.items():
                pil_image.save(images_dir / image_name)

    markdown_path = doc_dir / f"{stem}.md"
    markdown_path.write_text("".join(markdown_parts), encoding="utf-8")

    if save_html:
        html_path = doc_dir / f"{stem}.html"
        html_path.write_text("".join(html_parts), encoding="utf-8")

    summary = {
        "file_name": file_name,
        "num_pages": len(results),
        "total_token_count": tokens,
        "total_chunks": chunk_total,
        "total_images": image_total,
        "pages": page_records,
    }
    metadata_path = doc_dir / f"{stem}_metadata.json"
    metadata_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    click.echo(f" Saved: {markdown_path} ({len(results)} page(s))")
@click.command()
@click.argument("input_path", type=click.Path(exists=True, path_type=Path))
@click.argument("output_path", type=click.Path(path_type=Path))
@click.option(
    "--method",
    type=click.Choice(["hf", "vllm"], case_sensitive=False),
    default="vllm",
    help="Inference method: 'hf' for local model, 'vllm' for vLLM server.",
)
@click.option(
    "--page-range",
    type=str,
    default=None,
    help="Page range for PDFs (e.g., '1-5,7,9-12'). Only applicable to PDF files.",
)
@click.option(
    "--max-output-tokens",
    type=int,
    default=None,
    help="Maximum number of output tokens per page.",
)
@click.option(
    "--max-workers",
    type=int,
    default=None,
    help="Maximum number of parallel workers for vLLM inference.",
)
@click.option(
    "--max-retries",
    type=int,
    default=None,
    help="Maximum number of retries for vLLM inference.",
)
@click.option(
    "--include-images/--no-images",
    default=True,
    help="Include images in output.",
)
@click.option(
    "--include-headers-footers/--no-headers-footers",
    default=False,
    help="Include page headers and footers in output.",
)
@click.option(
    "--save-html/--no-html",
    default=True,
    help="Save HTML output files.",
)
@click.option(
    "--batch-size",
    type=int,
    default=1,
    help="Number of pages to process in a batch.",
)
# "--paginate-output" matches the hyphenated style of every other option;
# "--paginate_output" is kept as a hidden alias for backward compatibility.
@click.option(
    "--paginate-output",
    "--paginate_output",
    "paginate_output",
    is_flag=True,
    default=False,
    help="Insert page separators between pages in the merged output.",
)
def main(
    input_path: Path,
    output_path: Path,
    method: str,
    page_range: str,
    max_output_tokens: int,
    max_workers: int,
    max_retries: int,
    include_images: bool,
    include_headers_footers: bool,
    save_html: bool,
    batch_size: int,
    paginate_output: bool,
):
    """Run OCR over INPUT_PATH (file or directory) and write results to OUTPUT_PATH."""
    click.echo("Chandra CLI - Starting OCR processing")
    click.echo(f"Input: {input_path}")
    click.echo(f"Output: {output_path}")
    click.echo(f"Method: {method}")

    # Create output directory up front so per-file saves cannot fail on it.
    output_path.mkdir(parents=True, exist_ok=True)

    # Load the model once; it is reused across every file and batch.
    click.echo(f"\nLoading model with method '{method}'...")
    model = InferenceManager(method=method)
    click.echo("Model loaded successfully.")

    # Collect files to process (single file, or all supported files in a dir).
    files_to_process = get_supported_files(input_path)
    click.echo(f"\nFound {len(files_to_process)} file(s) to process.")
    if not files_to_process:
        click.echo("No supported files found. Exiting.")
        return

    for file_idx, file_path in enumerate(files_to_process, 1):
        click.echo(
            f"\n[{file_idx}/{len(files_to_process)}] Processing: {file_path.name}"
        )
        try:
            # Load page images; page_range only applies to PDFs.
            config = {"page_range": page_range} if page_range else {}
            images = load_file(str(file_path), config)
            click.echo(f" Loaded {len(images)} page(s)")

            # Accumulate all page results for this document.
            all_results = []
            for batch_start in range(0, len(images), batch_size):
                batch_end = min(batch_start + batch_size, len(images))
                batch_images = images[batch_start:batch_end]
                batch = [
                    BatchInputItem(image=img, prompt_type="ocr_layout")
                    for img in batch_images
                ]
                click.echo(f" Processing pages {batch_start + 1}-{batch_end}...")

                # Only forward optional settings the user actually supplied,
                # so the model's own defaults stay in effect otherwise.
                generate_kwargs = {
                    "include_images": include_images,
                    "include_headers_footers": include_headers_footers,
                }
                if max_output_tokens is not None:
                    generate_kwargs["max_output_tokens"] = max_output_tokens
                if method == "vllm":
                    # Worker/retry tuning is only meaningful for the vLLM backend.
                    if max_workers is not None:
                        generate_kwargs["max_workers"] = max_workers
                    if max_retries is not None:
                        generate_kwargs["max_retries"] = max_retries

                results = model.generate(batch, **generate_kwargs)
                all_results.extend(results)

            # Write merged markdown/HTML/metadata (and images) for this file.
            save_merged_output(
                output_path,
                file_path.name,
                all_results,
                save_images=include_images,
                save_html=save_html,
                paginate_output=paginate_output,
            )
            click.echo(f" Completed: {file_path.name}")
        except Exception as e:
            # Best-effort batch processing: report the failure and move on
            # to the next file instead of aborting the whole run.
            click.echo(f" Error processing {file_path.name}: {e}", err=True)
            continue

    click.echo(f"\nProcessing complete. Results saved to: {output_path}")
# Script entry point: delegate to the click command when executed directly.
if __name__ == "__main__":
    main()