mirror of
https://github.com/datalab-to/chandra.git
synced 2025-11-29 00:23:12 +00:00
Compare commits
4 Commits
7d967717a3
...
cba67c6d15
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cba67c6d15 | ||
|
|
c049e7524f | ||
|
|
7ac08e16e1 | ||
|
|
b96eb84094 |
@@ -1,8 +1,5 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from qwen_vl_utils import process_vision_info
|
|
||||||
from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor
|
|
||||||
|
|
||||||
from chandra.model.schema import BatchInputItem, GenerationResult
|
from chandra.model.schema import BatchInputItem, GenerationResult
|
||||||
from chandra.model.util import scale_to_fit
|
from chandra.model.util import scale_to_fit
|
||||||
from chandra.prompts import PROMPT_MAPPING
|
from chandra.prompts import PROMPT_MAPPING
|
||||||
@@ -16,6 +13,8 @@ def generate_hf(
|
|||||||
bbox_scale: int = settings.BBOX_SCALE,
|
bbox_scale: int = settings.BBOX_SCALE,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[GenerationResult]:
|
) -> List[GenerationResult]:
|
||||||
|
from qwen_vl_utils import process_vision_info
|
||||||
|
|
||||||
if max_output_tokens is None:
|
if max_output_tokens is None:
|
||||||
max_output_tokens = settings.MAX_OUTPUT_TOKENS
|
max_output_tokens = settings.MAX_OUTPUT_TOKENS
|
||||||
|
|
||||||
@@ -71,12 +70,15 @@ def process_batch_element(item: BatchInputItem, processor, bbox_scale: int):
|
|||||||
|
|
||||||
|
|
||||||
def load_model():
|
def load_model():
|
||||||
|
import torch
|
||||||
|
from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor
|
||||||
|
|
||||||
device_map = "auto"
|
device_map = "auto"
|
||||||
if settings.TORCH_DEVICE:
|
if settings.TORCH_DEVICE:
|
||||||
device_map = {"": settings.TORCH_DEVICE}
|
device_map = {"": settings.TORCH_DEVICE}
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"dtype": settings.TORCH_DTYPE,
|
"dtype": torch.bfloat16,
|
||||||
"device_map": device_map,
|
"device_map": device_map,
|
||||||
}
|
}
|
||||||
if settings.TORCH_ATTN:
|
if settings.TORCH_ATTN:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from functools import lru_cache
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from bs4 import BeautifulSoup, NavigableString
|
from bs4 import BeautifulSoup
|
||||||
from markdownify import MarkdownConverter, re_whitespace
|
from markdownify import MarkdownConverter, re_whitespace
|
||||||
|
|
||||||
from chandra.settings import settings
|
from chandra.settings import settings
|
||||||
@@ -89,39 +89,6 @@ def parse_html(
|
|||||||
return out_html
|
return out_html
|
||||||
|
|
||||||
|
|
||||||
def escape_dollars(text):
|
|
||||||
return text.replace("$", r"\$")
|
|
||||||
|
|
||||||
|
|
||||||
def get_formatted_table_text(element):
|
|
||||||
text = []
|
|
||||||
for content in element.contents:
|
|
||||||
if content is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if isinstance(content, NavigableString):
|
|
||||||
stripped = content.strip()
|
|
||||||
if stripped:
|
|
||||||
text.append(escape_dollars(stripped))
|
|
||||||
elif content.name == "br":
|
|
||||||
text.append("<br>")
|
|
||||||
elif content.name == "math":
|
|
||||||
text.append("$" + content.text + "$")
|
|
||||||
else:
|
|
||||||
content_str = escape_dollars(str(content))
|
|
||||||
text.append(content_str)
|
|
||||||
|
|
||||||
full_text = ""
|
|
||||||
for i, t in enumerate(text):
|
|
||||||
if t == "<br>":
|
|
||||||
full_text += t
|
|
||||||
elif i > 0 and text[i - 1] != "<br>":
|
|
||||||
full_text += " " + t
|
|
||||||
else:
|
|
||||||
full_text += t
|
|
||||||
return full_text
|
|
||||||
|
|
||||||
|
|
||||||
class Markdownify(MarkdownConverter):
|
class Markdownify(MarkdownConverter):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
from dotenv import find_dotenv
|
from dotenv import find_dotenv
|
||||||
from pydantic import computed_field
|
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
import torch
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
@@ -24,11 +22,6 @@ class Settings(BaseSettings):
|
|||||||
VLLM_GPUS: str = "0"
|
VLLM_GPUS: str = "0"
|
||||||
MAX_VLLM_RETRIES: int = 6
|
MAX_VLLM_RETRIES: int = 6
|
||||||
|
|
||||||
@computed_field
|
|
||||||
@property
|
|
||||||
def TORCH_DTYPE(self) -> torch.dtype:
|
|
||||||
return torch.bfloat16
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file = find_dotenv("local.env")
|
env_file = find_dotenv("local.env")
|
||||||
extra = "ignore"
|
extra = "ignore"
|
||||||
|
|||||||
Reference in New Issue
Block a user