Coverage for src / lilbee / providers / mtmd_backend.py: 100%
78 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
1"""Vision OCR loader that drives llama.cpp's mtmd pipeline with the GGUF's
2own chat template, so there's no projector-type-to-handler lookup table.
3"""
5from __future__ import annotations
7import logging
8import re
9from pathlib import Path
10from typing import Any
12from gguf import GGUFReader
14from lilbee.catalog.header_probe import GGUF_ARCH_KEY
15from lilbee.core.config import cfg
16from lilbee.providers.llama_cpp.abort_signal import abort_callback
17from lilbee.providers.llama_cpp.gguf_meta import (
18 find_mmproj_for_model,
19 read_gguf_metadata,
20 train_ctx_from_meta,
21)
22from lilbee.providers.llama_cpp.log_dispatch import (
23 import_llama_cpp,
24 install_llama_log_handler,
25 suppress_native_stderr,
26)
log = logging.getLogger(__name__)

# Image placeholder tokens that GGUF chat templates use to mark where a picture
# goes. Each one is rewritten to the Jinja expression below, because the upstream
# mtmd pipeline substitutes image URLs with its own media marker at render time.
# Replacement is exact (case-sensitive) on purpose: the tokens are machine-emitted
# and stable, and a case-insensitive replace could mangle unrelated Jinja names.
_GGUF_IMAGE_TOKENS: tuple[str, ...] = (
    "<|image_pad|>",
    "<image>",
    "<IMAGE>",
    "<__media__>",
    "<__image__>",
)
_IMAGE_URL_JINJA = "{{ content.image_url.url }}"

# Safety net applied after the known-token rewrite. Any angle-bracket token whose
# text mentions image/img/media/vision (in any case), such as ``<start_of_image>``
# or ``<media>``, is treated as an unsupported placeholder. The rewritten
# ``{{ ... }}`` marker uses braces rather than angle brackets, so it never matches.
# NOTE(review): delimiter tokens such as ``<|vision_start|>`` also match this
# pattern. A template that keeps them next to a known placeholder is therefore
# rejected. Confirm this is intended.
_UNKNOWN_IMAGE_TOKEN_RE = re.compile(
    r"<\|?[^<>]*(?:image|img|media|vision)[^<>]*\|?>",
    flags=re.IGNORECASE,
)
_TOKENIZER_CHAT_TEMPLATE_KEY = "tokenizer.chat_template"

# n_ctx used for a vision load when the GGUF metadata provides no usable
# ``context_length``. Most vision GGUFs do report their training context
# (commonly 4096, 8192 or 32768). This default only covers missing or
# unreadable metadata, so the loader still receives an explicit n_ctx.
_VISION_FALLBACK_N_CTX = 4096
def read_chat_template(model_path: Path) -> str | None:
    """Return the Jinja chat template embedded in a GGUF model, or None.

    Looks up the ``tokenizer.chat_template`` field with ``GGUFReader`` and
    decodes its payload as UTF-8, replacing any undecodable bytes.

    Args:
        model_path: Path to the GGUF model file.

    Returns:
        The template text. Returns None when the file cannot be read, the
        field is absent, or the field has no value entry to decode.
    """
    try:
        reader = GGUFReader(str(model_path))
        field = reader.get_field(_TOKENIZER_CHAT_TEMPLATE_KEY)
        if field is None:
            return None
        # ``field.data[0]`` selects the entry of ``field.parts`` holding the
        # string payload. The indexing is inside the ``try`` so that a malformed
        # or empty field yields None, like any other read failure, instead of
        # escaping as an IndexError.
        payload = bytes(field.parts[field.data[0]])
    except (OSError, ValueError, IndexError, KeyError):
        log.debug("Failed to read chat template from %s", model_path, exc_info=True)
        return None
    return payload.decode("utf-8", errors="replace")
def adapt_gguf_template_for_mtmd(template: str) -> str:
    """Rewrite known GGUF image placeholders to ``{{ content.image_url.url }}``.

    Every token in ``_GGUF_IMAGE_TOKENS`` is replaced with the Jinja image-URL
    expression. Afterwards the result is scanned for any other image/media-like
    angle-bracket token.

    Raises:
        ValueError: an unrecognized image placeholder is still present. Failing
            here surfaces a broken vision prompt at load time rather than letting
            OCR quality degrade silently.
    """
    adapted = template
    for placeholder in _GGUF_IMAGE_TOKENS:
        # str.replace is a no-op when the placeholder is absent.
        adapted = adapted.replace(placeholder, _IMAGE_URL_JINJA)

    unknown = _UNKNOWN_IMAGE_TOKEN_RE.search(adapted)
    if unknown is None:
        return adapted

    supported = ", ".join(_GGUF_IMAGE_TOKENS)
    raise ValueError(
        f"Unrecognized image placeholder {unknown.group()!r} in GGUF chat template; "
        f"supported tokens are: {supported}"
    )
def build_vision_chat_handler(model_path: Path, mmproj_path: Path) -> Any:
    """Build an mtmd chat handler that renders prompts with the GGUF's own template.

    The handler subclass sets ``DEFAULT_SYSTEM_MESSAGE`` to ``None`` so that no
    extra system turn is injected. When the GGUF has an embedded
    ``tokenizer.chat_template``, it is adapted for mtmd and installed as
    ``CHAT_FORMAT``. Otherwise the upstream handler's default template is used.

    Args:
        model_path: GGUF model whose embedded chat template is read.
        mmproj_path: Multimodal projector file handed to the handler.

    Returns:
        An instance of a ``Llava15ChatHandler`` subclass.
    """
    # Importing llama_cpp.llama_chat_format loads the parent package's native
    # library as a side effect. Call import_llama_cpp() first so that the
    # libvulkan-missing hint is surfaced.
    import_llama_cpp()
    from llama_cpp.llama_chat_format import Llava15ChatHandler

    # CHAT_FORMAT is a class attribute, so each call needs its own subclass.
    # A single module-level class would let the first model's template leak
    # into every later load.
    class _GgufTemplateChatHandler(Llava15ChatHandler):
        DEFAULT_SYSTEM_MESSAGE = None

    raw_template = read_chat_template(model_path)
    if raw_template is None:
        log.info(
            "Vision chat handler: no GGUF-embedded chat template for %s; using upstream default",
            model_path.name,
        )
    else:
        _GgufTemplateChatHandler.CHAT_FORMAT = adapt_gguf_template_for_mtmd(raw_template)
        log.info(
            "Vision chat handler: using GGUF-embedded template (%d bytes) from %s",
            len(raw_template),
            model_path.name,
        )

    return _GgufTemplateChatHandler(str(mmproj_path), verbose=False)
def load_vision_llama(
    model_path: Path,
    mmproj_path: Path | None = None,
    *,
    abort_callback_override: Any = None,
) -> Any:
    """Construct a vision-capable ``Llama`` that uses the GGUF-templated chat handler.

    Args:
        model_path: GGUF model to load.
        mmproj_path: Multimodal projector file. When omitted, it is resolved
            with ``find_mmproj_for_model(model_path)``.
        abort_callback_override: Replaces the module's default abort callback
            when given. Pool workers use this to bind a callback that reads
            their shared ``mp.Value`` abort flag.

    Returns:
        The loaded ``Llama`` instance.
    """
    import os

    # The native library is heavy, so it is imported lazily here.
    llama_cls = import_llama_cpp().Llama
    install_llama_log_handler()

    projector = mmproj_path if mmproj_path is not None else find_mmproj_for_model(model_path)
    handler = build_vision_chat_handler(model_path, projector)

    # The vision encoder runs on the GPU through mtmd, but image preprocessing,
    # tokenization and sampling still run on the CPU. llama-cpp-python's default
    # of roughly half the cores starves those steps, so use every core
    # (matching Ollama).
    thread_count = os.cpu_count() or 4
    llama_kwargs: dict[str, Any] = dict(
        model_path=str(model_path),
        chat_handler=handler,
        verbose=False,
        n_gpu_layers=-1,
        n_ctx=_resolve_vision_n_ctx(model_path),
        n_threads=thread_count,
        n_threads_batch=thread_count,
    )

    # mtmd only uses flash attention for the vision pass when the backing Llama
    # has it enabled. Left unset, it defaults to disabled, which slows the
    # image-token prefill (thousands of tokens per page). Honor
    # cfg.flash_attention, treating None/auto as on, like the chat loader does.
    if cfg.flash_attention is not False:
        llama_kwargs["flash_attn"] = True
    if cfg.main_gpu is not None:
        llama_kwargs["main_gpu"] = cfg.main_gpu
    llama_kwargs["abort_callback"] = (
        abort_callback if abort_callback_override is None else abort_callback_override
    )

    llama = suppress_native_stderr(llama_cls, **llama_kwargs)

    loaded_meta = getattr(llama, "metadata", {}) or {}
    ctx_getter = getattr(llama, "n_ctx", None)
    effective_ctx = ctx_getter() if callable(ctx_getter) else "?"
    log.info(
        "Vision model loaded: model=%s mmproj=%s n_ctx=%s arch=%s",
        model_path.name,
        projector.name,
        effective_ctx,
        loaded_meta.get(GGUF_ARCH_KEY, "?"),
    )
    return llama
def _resolve_vision_n_ctx(model_path: Path) -> int:
    """Choose n_ctx for a vision load from the model's training context.

    The training context comes from the GGUF metadata via ``read_gguf_metadata``
    and ``train_ctx_from_meta``. If the metadata cannot be read, it falls back to
    ``_VISION_FALLBACK_N_CTX``.

    The chat-oriented ``cfg.num_ctx`` is deliberately not applied. A vision pass
    packs image-token embeddings plus the prompt (often hundreds to a few
    thousand tokens per page), and clamping to a small chat context truncates
    OCR output. An explicit value is returned rather than 0 because
    ``_halve_ctx_for_retry`` cannot bisect from 0 on the OOM-retry path.
    """
    meta = None
    try:
        meta = read_gguf_metadata(model_path)
    except Exception:
        # Best effort: any metadata failure just means the fallback is used.
        log.debug("read_gguf_metadata failed for vision %s", model_path, exc_info=True)
    return train_ctx_from_meta(meta, fallback=_VISION_FALLBACK_N_CTX, model_path=model_path)