mirror of
https://github.com/datalab-to/chandra.git
synced 2025-11-30 00:53:09 +00:00
Refactor
This commit is contained in:
@@ -2,19 +2,23 @@ import pypdfium2 as pdfium
|
||||
import streamlit as st
|
||||
from PIL import Image
|
||||
|
||||
from chandra.layout import parse_layout, draw_layout
|
||||
from chandra.load import load_pdf_images
|
||||
from chandra.model import load, BatchItem, generate
|
||||
from chandra.model import InferenceManager
|
||||
from chandra.util import draw_layout
|
||||
from chandra.input import load_pdf_images
|
||||
from chandra.model.schema import BatchInputItem
|
||||
from chandra.output import parse_layout
|
||||
|
||||
|
||||
@st.cache_resource()
|
||||
def load_model():
|
||||
return load()
|
||||
def load_model(method: str):
|
||||
return InferenceManager(method=method)
|
||||
|
||||
|
||||
@st.cache_data()
|
||||
def get_page_image(pdf_file, page_num):
|
||||
return load_pdf_images(pdf_file, [page_num])[0]
|
||||
|
||||
|
||||
@st.cache_data()
|
||||
def page_counter(pdf_file):
|
||||
doc = pdfium.PdfDocument(pdf_file)
|
||||
@@ -22,40 +26,45 @@ def page_counter(pdf_file):
|
||||
doc.close()
|
||||
return doc_len
|
||||
|
||||
# Function for OCR
|
||||
|
||||
def ocr_layout(
|
||||
img: Image.Image,
|
||||
model=None,
|
||||
) -> (Image.Image, str):
|
||||
batch = BatchItem(
|
||||
images=[img],
|
||||
batch = BatchInputItem(
|
||||
image=img,
|
||||
prompt_type="ocr_layout",
|
||||
)
|
||||
html = generate([batch], model=model)[0]
|
||||
print(f"Generated HTML: {html[:500]}...")
|
||||
layout = parse_layout(html, img)
|
||||
result = model.generate([batch])[0]
|
||||
layout = parse_layout(result.raw, img)
|
||||
layout_image = draw_layout(img, layout)
|
||||
return html, layout_image
|
||||
return result.html, layout_image, result.markdown
|
||||
|
||||
def ocr(
|
||||
img: Image.Image,
|
||||
) -> str:
|
||||
batch = BatchItem(
|
||||
images=[img],
|
||||
prompt_type="ocr"
|
||||
)
|
||||
return generate([batch], model=model)[0]
|
||||
|
||||
st.set_page_config(layout="wide")
|
||||
col1, col2 = st.columns([0.5, 0.5])
|
||||
|
||||
model = load_model()
|
||||
|
||||
st.markdown("""
|
||||
# Chandra OCR Demo
|
||||
|
||||
This app will let you try chandra, a multilingual OCR toolkit.
|
||||
This app will let you try chandra, a layout-aware vision language model.
|
||||
""")
|
||||
|
||||
# Get model mode selection
|
||||
model_mode = st.sidebar.selectbox(
|
||||
"Model Mode",
|
||||
["None", "hf", "vllm"],
|
||||
index=0,
|
||||
help="Select how to run inference: hf loads the model in memory using huggingface transformers, vllm connects to a running vLLM server."
|
||||
)
|
||||
|
||||
# Only load model if a mode is selected
|
||||
model = None
|
||||
if model_mode == "None":
|
||||
st.warning("Please select a model mode (Local Model or vLLM Server) to run OCR.")
|
||||
else:
|
||||
model = load_model(model_mode)
|
||||
|
||||
in_file = st.sidebar.file_uploader(
|
||||
"PDF file or image:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"]
|
||||
)
|
||||
@@ -77,37 +86,35 @@ else:
|
||||
page_number = None
|
||||
|
||||
run_ocr = st.sidebar.button("Run OCR")
|
||||
prompt_type = st.sidebar.selectbox(
|
||||
"Prompt type",
|
||||
["ocr_layout", "ocr"],
|
||||
index=0,
|
||||
help="Select the prompt type for OCR.",
|
||||
)
|
||||
|
||||
if pil_image is None:
|
||||
st.stop()
|
||||
|
||||
if run_ocr:
|
||||
if prompt_type == "ocr_layout":
|
||||
pred, layout_image = ocr_layout(
|
||||
pil_image,
|
||||
)
|
||||
if model_mode == "None":
|
||||
st.error("Please select a model mode (hf or vllm) to run OCR.")
|
||||
else:
|
||||
pred = ocr(
|
||||
pred, layout_image, markdown = ocr_layout(
|
||||
pil_image,
|
||||
model,
|
||||
)
|
||||
layout_image = None
|
||||
|
||||
with col1:
|
||||
html_tab, text_tab, layout_tab = st.tabs(["HTML", "HTML as text", "Layout Image"])
|
||||
with html_tab:
|
||||
st.markdown(pred, unsafe_allow_html=True)
|
||||
with text_tab:
|
||||
st.text(pred)
|
||||
with col1:
|
||||
html_tab, text_tab, layout_tab = st.tabs(["HTML", "HTML as text", "Layout Image"])
|
||||
with html_tab:
|
||||
st.markdown(markdown, unsafe_allow_html=True)
|
||||
st.download_button(
|
||||
label="Download Markdown",
|
||||
data=markdown,
|
||||
file_name=f"{in_file.name.rsplit('.', 1)[0]}_page{page_number if page_number is not None else 0}.md",
|
||||
mime="text/markdown",
|
||||
)
|
||||
with text_tab:
|
||||
st.text(pred)
|
||||
|
||||
if layout_image:
|
||||
with layout_tab:
|
||||
st.image(layout_image, caption="Detected Layout", use_container_width=True)
|
||||
if layout_image:
|
||||
with layout_tab:
|
||||
st.image(layout_image, caption="Detected Layout", use_container_width=True)
|
||||
|
||||
with col2:
|
||||
st.image(pil_image, caption="Uploaded Image", use_container_width=True)
|
||||
|
||||
Reference in New Issue
Block a user