mirror of
https://github.com/datalab-to/chandra.git
synced 2025-11-29 00:23:12 +00:00
Compare commits
4 Commits
7d967717a3
...
cba67c6d15
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cba67c6d15 | ||
|
|
c049e7524f | ||
|
|
7ac08e16e1 | ||
|
|
b96eb84094 |
@@ -1,8 +1,5 @@
|
||||
from typing import List
|
||||
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor
|
||||
|
||||
from chandra.model.schema import BatchInputItem, GenerationResult
|
||||
from chandra.model.util import scale_to_fit
|
||||
from chandra.prompts import PROMPT_MAPPING
|
||||
@@ -16,6 +13,8 @@ def generate_hf(
|
||||
bbox_scale: int = settings.BBOX_SCALE,
|
||||
**kwargs,
|
||||
) -> List[GenerationResult]:
|
||||
from qwen_vl_utils import process_vision_info
|
||||
|
||||
if max_output_tokens is None:
|
||||
max_output_tokens = settings.MAX_OUTPUT_TOKENS
|
||||
|
||||
@@ -71,12 +70,15 @@ def process_batch_element(item: BatchInputItem, processor, bbox_scale: int):
|
||||
|
||||
|
||||
def load_model():
|
||||
import torch
|
||||
from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor
|
||||
|
||||
device_map = "auto"
|
||||
if settings.TORCH_DEVICE:
|
||||
device_map = {"": settings.TORCH_DEVICE}
|
||||
|
||||
kwargs = {
|
||||
"dtype": settings.TORCH_DTYPE,
|
||||
"dtype": torch.bfloat16,
|
||||
"device_map": device_map,
|
||||
}
|
||||
if settings.TORCH_ATTN:
|
||||
|
||||
@@ -6,7 +6,7 @@ from functools import lru_cache
|
||||
|
||||
import six
|
||||
from PIL import Image
|
||||
from bs4 import BeautifulSoup, NavigableString
|
||||
from bs4 import BeautifulSoup
|
||||
from markdownify import MarkdownConverter, re_whitespace
|
||||
|
||||
from chandra.settings import settings
|
||||
@@ -89,39 +89,6 @@ def parse_html(
|
||||
return out_html
|
||||
|
||||
|
||||
def escape_dollars(text):
    """Escape every literal dollar sign so it is not read as a math delimiter."""
    return "\\$".join(text.split("$"))
|
||||
|
||||
|
||||
def get_formatted_table_text(element):
    """Flatten a table cell's children into one text string.

    Text nodes are stripped and dollar-escaped, ``<br>`` tags are kept as
    literal ``<br>`` markers, ``<math>`` tags are wrapped in ``$...$``
    delimiters, and any other child is serialized and dollar-escaped.
    Fragments are joined with a single space, except that no space is
    inserted on either side of a ``<br>`` marker.
    """
    fragments = []
    for child in element.contents:
        if child is None:
            continue

        if isinstance(child, NavigableString):
            piece = child.strip()
            if piece:
                fragments.append(escape_dollars(piece))
        elif child.name == "br":
            fragments.append("<br>")
        elif child.name == "math":
            # Wrap inline math content in TeX-style delimiters.
            fragments.append("$" + child.text + "$")
        else:
            fragments.append(escape_dollars(str(child)))

    joined = ""
    for idx, fragment in enumerate(fragments):
        # A space separates two ordinary fragments; <br> glues to its
        # neighbors on both sides.
        needs_space = (
            fragment != "<br>" and idx > 0 and fragments[idx - 1] != "<br>"
        )
        joined += (" " + fragment) if needs_space else fragment
    return joined
|
||||
|
||||
|
||||
class Markdownify(MarkdownConverter):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from dotenv import find_dotenv
|
||||
from pydantic import computed_field
|
||||
from pydantic_settings import BaseSettings
|
||||
import torch
|
||||
import os
|
||||
|
||||
|
||||
@@ -24,11 +22,6 @@ class Settings(BaseSettings):
|
||||
VLLM_GPUS: str = "0"
|
||||
MAX_VLLM_RETRIES: int = 6
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def TORCH_DTYPE(self) -> torch.dtype:
|
||||
return torch.bfloat16
|
||||
|
||||
class Config:
|
||||
env_file = find_dotenv("local.env")
|
||||
extra = "ignore"
|
||||
|
||||
Reference in New Issue
Block a user