From 63c88d644da723be991d81ed83add9a41aa3067f Mon Sep 17 00:00:00 2001 From: Vik Paruchuri Date: Tue, 21 Oct 2025 11:01:02 -0400 Subject: [PATCH] Fix attn impl --- .gitignore | 1 + README.md | 20 ++++++++++---------- chandra/model/hf.py | 13 +++++++++---- chandra/settings.py | 11 ----------- pyproject.toml | 2 +- 5 files changed, 21 insertions(+), 26 deletions(-) diff --git a/.gitignore b/.gitignore index 9676325..158cdbd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ local.env experiments .claude +.DS_Store # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/README.md b/README.md index 9af3fe3..e1ba99a 100644 --- a/README.md +++ b/README.md @@ -154,17 +154,17 @@ VLLM_GPUS=0 ## Benchmark table -| **Model** | ArXiv | Old Scans Math | Tables | Old Scans | Headers and Footers | Multi column | Long tiny text | Base | Overall | Source | -|:----------|:--------:|:--------------:|:--------:|:---------:|:-------------------:|:------------:|:--------------:|:----:|:--------------:|:------:| -| Datalab Chandra v0.1.0 | 82.2 | **80.3** | **88.0** | **50.4** | 90.8 | 81.2 | **92.3** | **99.9** | **83.1 ± 0.9** | Own benchmarks | -| Datalab Marker v1.10.0 | **83.8** | 69.7 | 74.8 | 32.3 | 86.6 | 79.4 | 85.7 | 99.6 | 76.5 ± 1.0 | Own benchmarks | -| Mistral OCR API | 77.2 | 67.5 | 60.6 | 29.3 | 93.6 | 71.3 | 77.1 | 99.4 | 72.0 ± 1.1 | olmocr repo | -| Deepseek OCR | 75.2 | 72.3 | 79.7 | 33.3 | 96.1 | 66.7 | 80.1 | 99.7 | 75.4 ± 1.0 | Own benchmarks | -| GPT-4o (Anchored) | 53.5 | 74.5 | 70.0 | 40.7 | 93.8 | 69.3 | 60.6 | 96.8 | 69.9 ± 1.1 | olmocr repo | +| **Model** | ArXiv | Old Scans Math | Tables | Old Scans | Headers and Footers | Multi column | Long tiny text | Base | Overall | Source | +|:--------------------------|:--------:|:--------------:|:--------:|:---------:|:-------------------:|:------------:|:--------------:|:----:|:--------------:|:------:| +| Datalab Chandra v0.1.0 | 82.2 | **80.3** | **88.0** | **50.4** | 90.8 | 81.2 | **92.3** | **99.9** | **83.1 ± 0.9** | Own benchmarks | +| Datalab Marker v1.10.0 | **83.8** | 69.7 | 74.8 | 32.3 | 86.6 | 79.4 | 85.7 | 99.6 | 76.5 ± 1.0 | Own benchmarks | +| Mistral OCR API | 77.2 | 67.5 | 60.6 | 29.3 | 93.6 | 71.3 | 77.1 | 99.4 | 72.0 ± 1.1 | olmocr repo | +| Deepseek OCR | 75.2 | 72.3 | 79.7 | 33.3 | 96.1 | 66.7 | 80.1 | 99.7 | 75.4 ± 1.0 | Own benchmarks | +| GPT-4o (Anchored) | 53.5 | 74.5 | 70.0 | 40.7 | 93.8 | 69.3 | 60.6 | 96.8 | 69.9 ± 1.1 | olmocr repo | | Gemini Flash 2 (Anchored) | 54.5 | 56.1 | 72.1 | 34.2 | 64.7 | 61.5 | 71.5 | 95.6 | 63.8 ± 1.2 | olmocr repo | -| Qwen 3 VL | 70.2 | 75.1 | 45.6 | 37.5 | 89.1 | 62.1 | 43.0 | 94.3 | 64.6 ± 1.1 | Own benchmarks | -| olmOCR v0.3.0 | 78.6 | 79.9 | 72.9 | 43.9 | **95.1** | 77.3 | 81.2 | 98.9 | 78.5 ± 1.1 | olmocr repo | -| dots.ocr | 82.1 | 64.2 | 88.3 | 40.9 | 94.1 | **82.4** | 81.2 | 99.5 | 79.1 ± 1.0 | dots.ocr repo | +| Qwen 3 VL 8B | 70.2 | 75.1 | 45.6 | 37.5 | 89.1 | 62.1 | 43.0 | 94.3 | 64.6 ± 1.1 | Own benchmarks | +| olmOCR v0.3.0 | 78.6 | 79.9 | 72.9 | 43.9 | **95.1** | 77.3 | 81.2 | 98.9 | 78.5 ± 1.1 | olmocr repo | +| dots.ocr | 82.1 | 64.2 | 88.3 | 40.9 | 94.1 | **82.4** | 81.2 | 99.5 | 79.1 ± 1.0 | dots.ocr repo | # Commercial usage diff --git a/chandra/model/hf.py b/chandra/model/hf.py index 37f4cba..6bedf49 100644 --- a/chandra/model/hf.py +++ b/chandra/model/hf.py @@ -68,11 +68,16 @@ def load_model(): device_map = "auto" if settings.TORCH_DEVICE: device_map = {"": settings.TORCH_DEVICE} + + kwargs = { + "dtype": settings.TORCH_DTYPE, + "device_map": device_map, + } + if settings.TORCH_ATTN: + kwargs["attn_implementation"] = settings.TORCH_ATTN_IMPLEMENTATION + model = Qwen3VLForConditionalGeneration.from_pretrained( - settings.MODEL_CHECKPOINT, - dtype=settings.TORCH_DTYPE, - device_map=device_map, - attn_implementation=settings.TORCH_ATTN_IMPLEMENTATION, + settings.MODEL_CHECKPOINT, **kwargs ) model = model.eval() processor = Qwen3VLProcessor.from_pretrained(settings.MODEL_CHECKPOINT) diff --git a/chandra/settings.py b/chandra/settings.py index 151472b..2c59ec3 100644 --- a/chandra/settings.py +++ b/chandra/settings.py @@ -42,17 +42,6 @@ class Settings(BaseSettings): def TORCH_DTYPE(self) -> torch.dtype: return torch.bfloat16 - @computed_field - @property - def TORCH_ATTN_IMPLEMENTATION(self) -> str: - if self.TORCH_ATTN is not None: - return self.TORCH_ATTN - - if self.TORCH_DEVICE_MODEL == "cuda": - return "flash_attention_2" - else: - return "sdpa" - class Config: env_file = find_dotenv("local.env") extra = "ignore" diff --git a/pyproject.toml b/pyproject.toml index 2997e1a..fd67ac2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "chandra-ocr" -version = "0.1.1" +version = "0.1.2" description = "OCR model that converts documents to markdown, HTML, or JSON." readme = "README.md" requires-python = ">=3.10"