Mirror of https://github.com/datalab-to/chandra.git, synced 2025-11-29 16:43:11 +00:00
Initial benchmarks
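Adds chandra_cli.py, a click-based command-line interface that runs chandra OCR over a single file or a directory of images/PDFs and writes merged markdown, HTML, and JSON metadata for each input. A hedged usage sketch follows the diff below.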
chandra_cli.py (executable file, 285 lines added)
@@ -0,0 +1,285 @@
import json
from pathlib import Path
from typing import List

import click

from chandra.input import load_file
from chandra.model import InferenceManager
from chandra.model.schema import BatchInputItem


def get_supported_files(input_path: Path) -> List[Path]:
    """Get list of supported image/PDF files from path."""
    supported_extensions = {
        ".pdf",
        ".png",
        ".jpg",
        ".jpeg",
        ".gif",
        ".webp",
        ".tiff",
        ".bmp",
    }

    if input_path.is_file():
        if input_path.suffix.lower() in supported_extensions:
            return [input_path]
        else:
            raise click.BadParameter(f"Unsupported file type: {input_path.suffix}")

    elif input_path.is_dir():
        files = []
        for ext in supported_extensions:
            files.extend(input_path.glob(f"*{ext}"))
            files.extend(input_path.glob(f"*{ext.upper()}"))
        return sorted(files)

    else:
        raise click.BadParameter(f"Path does not exist: {input_path}")

def save_merged_output(
    output_dir: Path,
    file_name: str,
    results: List,
    save_images: bool = True,
    save_html: bool = True,
    paginate_output: bool = False,
):
    """Save merged OCR results for all pages to output directory."""
    # Create subfolder for this file
    safe_name = Path(file_name).stem
    file_output_dir = output_dir / safe_name
    file_output_dir.mkdir(parents=True, exist_ok=True)

    # Merge all pages
    all_markdown = []
    all_html = []
    all_metadata = []
    total_tokens = 0
    total_chunks = 0
    total_images = 0

    # Process each page result
    for page_num, result in enumerate(results):
        # Add page separator for multi-page documents
        if page_num > 0 and paginate_output:
            all_markdown.append(f"\n\n{page_num}" + "-" * 48 + "\n\n")
            all_html.append(f"\n\n<!-- Page {page_num + 1} -->\n\n")

        all_markdown.append(result.markdown)
        all_html.append(result.html)

        # Accumulate metadata
        total_tokens += result.token_count
        total_chunks += len(result.chunks)
        total_images += len(result.images)

        page_metadata = {
            "page_num": page_num,
            "page_box": result.page_box,
            "token_count": result.token_count,
            "num_chunks": len(result.chunks),
            "num_images": len(result.images),
        }
        all_metadata.append(page_metadata)

        # Save extracted images if requested
        if save_images and result.images:
            images_dir = file_output_dir / "images"
            images_dir.mkdir(exist_ok=True)

            for img_name, pil_image in result.images.items():
                img_path = images_dir / img_name
                pil_image.save(img_path)

    # Save merged markdown
    markdown_path = file_output_dir / f"{safe_name}.md"
    with open(markdown_path, "w", encoding="utf-8") as f:
        f.write("".join(all_markdown))

    # Save merged HTML if requested
    if save_html:
        html_path = file_output_dir / f"{safe_name}.html"
        with open(html_path, "w", encoding="utf-8") as f:
            f.write("".join(all_html))

    # Save combined metadata
    metadata = {
        "file_name": file_name,
        "num_pages": len(results),
        "total_token_count": total_tokens,
        "total_chunks": total_chunks,
        "total_images": total_images,
        "pages": all_metadata,
    }
    metadata_path = file_output_dir / f"{safe_name}_metadata.json"
    with open(metadata_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    click.echo(f" Saved: {markdown_path} ({len(results)} page(s))")

@click.command()
@click.argument("input_path", type=click.Path(exists=True, path_type=Path))
@click.argument("output_path", type=click.Path(path_type=Path))
@click.option(
    "--method",
    type=click.Choice(["hf", "vllm"], case_sensitive=False),
    default="vllm",
    help="Inference method: 'hf' for local model, 'vllm' for vLLM server.",
)
@click.option(
    "--page-range",
    type=str,
    default=None,
    help="Page range for PDFs (e.g., '1-5,7,9-12'). Only applicable to PDF files.",
)
@click.option(
    "--max-output-tokens",
    type=int,
    default=None,
    help="Maximum number of output tokens per page.",
)
@click.option(
    "--max-workers",
    type=int,
    default=None,
    help="Maximum number of parallel workers for vLLM inference.",
)
@click.option(
    "--max-retries",
    type=int,
    default=None,
    help="Maximum number of retries for vLLM inference.",
)
@click.option(
    "--include-images/--no-images",
    default=True,
    help="Include images in output.",
)
@click.option(
    "--include-headers-footers/--no-headers-footers",
    default=False,
    help="Include page headers and footers in output.",
)
@click.option(
    "--save-html/--no-html",
    default=True,
    help="Save HTML output files.",
)
@click.option(
    "--batch-size",
    type=int,
    default=1,
    help="Number of pages to process in a batch.",
)
@click.option(
    "--paginate-output",
    is_flag=True,
    default=False,
    help="Insert page separators between pages in the merged output.",
)
def main(
    input_path: Path,
    output_path: Path,
    method: str,
    page_range: str,
    max_output_tokens: int,
    max_workers: int,
    max_retries: int,
    include_images: bool,
    include_headers_footers: bool,
    save_html: bool,
    batch_size: int,
    paginate_output: bool,
):
    click.echo("Chandra CLI - Starting OCR processing")
    click.echo(f"Input: {input_path}")
    click.echo(f"Output: {output_path}")
    click.echo(f"Method: {method}")

    # Create output directory
    output_path.mkdir(parents=True, exist_ok=True)

    # Load model
    click.echo(f"\nLoading model with method '{method}'...")
    model = InferenceManager(method=method)
    click.echo("Model loaded successfully.")

    # Get files to process
    files_to_process = get_supported_files(input_path)
    click.echo(f"\nFound {len(files_to_process)} file(s) to process.")

    if not files_to_process:
        click.echo("No supported files found. Exiting.")
        return

    # Process each file
    for file_idx, file_path in enumerate(files_to_process, 1):
        click.echo(
            f"\n[{file_idx}/{len(files_to_process)}] Processing: {file_path.name}"
        )

        try:
            # Load images from file
            config = {"page_range": page_range} if page_range else {}
            images = load_file(str(file_path), config)
            click.echo(f" Loaded {len(images)} page(s)")

            # Accumulate all results for this document
            all_results = []

            # Process pages in batches
            for batch_start in range(0, len(images), batch_size):
                batch_end = min(batch_start + batch_size, len(images))
                batch_images = images[batch_start:batch_end]

                # Create batch input items
                batch = [
                    BatchInputItem(image=img, prompt_type="ocr_layout")
                    for img in batch_images
                ]

                # Run inference
                click.echo(f" Processing pages {batch_start + 1}-{batch_end}...")

                # Build kwargs for generate
                generate_kwargs = {
                    "include_images": include_images,
                    "include_headers_footers": include_headers_footers,
                }

                if max_output_tokens is not None:
                    generate_kwargs["max_output_tokens"] = max_output_tokens

                if method == "vllm":
                    if max_workers is not None:
                        generate_kwargs["max_workers"] = max_workers
                    if max_retries is not None:
                        generate_kwargs["max_retries"] = max_retries

                results = model.generate(batch, **generate_kwargs)
                all_results.extend(results)

            # Save merged output for all pages
            save_merged_output(
                output_path,
                file_path.name,
                all_results,
                save_images=include_images,
                save_html=save_html,
                paginate_output=paginate_output,
            )

            click.echo(f" Completed: {file_path.name}")

        except Exception as e:
            click.echo(f" Error processing {file_path.name}: {e}", err=True)
            continue

    click.echo(f"\nProcessing complete. Results saved to: {output_path}")


if __name__ == "__main__":
    main()
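A quick invocation sketch based on the options defined above; the input file, output directory, and page range are illustrative, and the default --method vllm assumes a chandra-compatible vLLM server is reachable:

    python chandra_cli.py docs/report.pdf ./ocr_out --page-range 1-5 --batch-size 4 --paginate-output

Per save_merged_output, this run would write ocr_out/report/report.md, report.html, and report_metadata.json, plus an images/ subfolder for any extracted figures.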