diff --git a/chandra/model/hf.py b/chandra/model/hf.py
index 6bedf49..0803f7d 100644
--- a/chandra/model/hf.py
+++ b/chandra/model/hf.py
@@ -73,7 +73,7 @@ def load_model():
         "dtype": settings.TORCH_DTYPE,
         "device_map": device_map,
     }
-    if settings.TORCH_ATTN:
+    if settings.TORCH_ATTN_IMPLEMENTATION:
         kwargs["attn_implementation"] = settings.TORCH_ATTN_IMPLEMENTATION

     model = Qwen3VLForConditionalGeneration.from_pretrained(
diff --git a/chandra/settings.py b/chandra/settings.py
index 2c59ec3..67bb55d 100644
--- a/chandra/settings.py
+++ b/chandra/settings.py
@@ -13,7 +13,7 @@ class Settings(BaseSettings):
     MODEL_CHECKPOINT: str = "datalab-to/chandra"
     TORCH_DEVICE: str | None = None
     MAX_OUTPUT_TOKENS: int = 8192
-    TORCH_ATTN: str | None = None
+    TORCH_ATTN_IMPLEMENTATION: str | None = None

     # vLLM server settings
     VLLM_API_KEY: str = "EMPTY"