Coverage for src / lilbee / providers / mtmd_backend.py: 100%

78 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-28 01:01 +0000

1"""Vision OCR loader that drives llama.cpp's mtmd pipeline with the GGUF's 

2own chat template, so there's no projector-type-to-handler lookup table. 

3""" 

4 

5from __future__ import annotations 

6 

7import logging 

8import re 

9from pathlib import Path 

10from typing import Any 

11 

12from gguf import GGUFReader 

13 

14from lilbee.catalog.header_probe import GGUF_ARCH_KEY 

15from lilbee.core.config import cfg 

16from lilbee.providers.llama_cpp.abort_signal import abort_callback 

17from lilbee.providers.llama_cpp.gguf_meta import ( 

18 find_mmproj_for_model, 

19 read_gguf_metadata, 

20 train_ctx_from_meta, 

21) 

22from lilbee.providers.llama_cpp.log_dispatch import ( 

23 import_llama_cpp, 

24 install_llama_log_handler, 

25 suppress_native_stderr, 

26) 

27 

28log = logging.getLogger(__name__) 

29 

30 

# Image-placeholder tokens seen in GGUF chat templates. The upstream
# mtmd pipeline substitutes image URLs with mtmd's media marker, so
# these get rewritten to {{ content.image_url.url }} before rendering.
# Case matters: GGUF templates are machine-emitted and stable, so a
# case-insensitive replace would risk corrupting unrelated Jinja
# identifiers. Because matching is exact, both ``<image>`` and
# ``<IMAGE>`` are listed as separate entries.
_GGUF_IMAGE_TOKENS: tuple[str, ...] = (
    "<|image_pad|>",
    "<image>",
    "<IMAGE>",
    "<__media__>",
    "<__image__>",
)
# Jinja expression written in place of every token listed in
# ``_GGUF_IMAGE_TOKENS`` when a template is adapted.
_IMAGE_URL_JINJA = "{{ content.image_url.url }}"

45 

# Angle-bracket placeholder whose inner text names an image/media slot, e.g.
# ``<start_of_image>`` or ``<media>``. Used after the known-token rewrite to
# catch an unrecognized image placeholder; the adapted ``{{ ... }}`` Jinja marker
# uses braces, not angle brackets, so it never matches here.
#
# Pattern shape: ``<``, an optional ``|``, a run of non-bracket characters that
# contains ``image``, ``img``, ``media`` or ``vision`` (compared
# case-insensitively), an optional ``|``, then the closing ``>``.
_UNKNOWN_IMAGE_TOKEN_RE = re.compile(
    r"<\|?[^<>]*(?:image|img|media|vision)[^<>]*\|?>", re.IGNORECASE
)

53 

# GGUF metadata field name looked up by ``read_chat_template`` to fetch the
# model's embedded chat template.
_TOKENIZER_CHAT_TEMPLATE_KEY = "tokenizer.chat_template"

# Handed to ``train_ctx_from_meta`` as its ``fallback`` argument by
# ``_resolve_vision_n_ctx``.
_VISION_FALLBACK_N_CTX = 4096
"""n_ctx for a vision load when the GGUF has no ``context_length`` in metadata.

Most vision GGUFs report their training context (typical values: 4096, 8192,
32768); this covers the rare missing/unreadable-metadata case so the loader
still gets a sensible explicit n_ctx.
"""

63 

64 

def read_chat_template(model_path: Path) -> str | None:
    """Extract the Jinja chat template stored in a GGUF file's metadata.

    Returns the template decoded as UTF-8 (undecodable bytes are replaced),
    or ``None`` when opening/reading the file fails or the file carries no
    chat-template field.
    """
    try:
        gguf = GGUFReader(str(model_path))
        template_field = gguf.get_field(_TOKENIZER_CHAT_TEMPLATE_KEY)
    except (OSError, ValueError, IndexError, KeyError):
        # A file that cannot be read is treated the same as "no template";
        # the traceback is kept at debug level for diagnosis.
        log.debug("Failed to read chat template from %s", model_path, exc_info=True)
        return None

    if template_field is None:
        return None

    # The first entry of ``data`` is the index into ``parts`` of the part
    # that holds the raw template bytes.
    payload_index = template_field.data[0]
    raw_template = bytes(template_field.parts[payload_index])
    return raw_template.decode("utf-8", errors="replace")

76 

77 

def adapt_gguf_template_for_mtmd(template: str) -> str:
    """Swap every known image placeholder for ``{{ content.image_url.url }}``.

    After the swap the template is scanned for any remaining angle-bracket
    placeholder that looks image/media related. Finding one raises
    ``ValueError`` so that an unsupported vision template is rejected at load
    time instead of quietly producing degraded OCR.
    """
    adapted = template
    for placeholder in _GGUF_IMAGE_TOKENS:
        # ``str.replace`` leaves the text untouched when the token is absent.
        adapted = adapted.replace(placeholder, _IMAGE_URL_JINJA)

    unknown = _UNKNOWN_IMAGE_TOKEN_RE.search(adapted)
    if unknown is None:
        return adapted

    known_tokens = ", ".join(_GGUF_IMAGE_TOKENS)
    raise ValueError(
        f"Unrecognized image placeholder {unknown.group()!r} in GGUF chat template; "
        f"supported tokens are: {known_tokens}"
    )

95 

96 

def build_vision_chat_handler(model_path: Path, mmproj_path: Path) -> Any:
    """Build an mtmd chat handler that renders with the GGUF's own template.

    The handler subclass sets ``DEFAULT_SYSTEM_MESSAGE`` to ``None`` so that
    no extra system turn is injected. When the GGUF carries no
    ``tokenizer.chat_template`` the subclass keeps the upstream default
    ``CHAT_FORMAT``.
    """
    # Call the package loader first so the libvulkan-missing hint surfaces
    # before the submodule import below, which triggers the parent package's
    # native loader as a side effect.
    import_llama_cpp()
    from llama_cpp.llama_chat_format import Llava15ChatHandler

    # A fresh subclass per call: ``CHAT_FORMAT`` is assigned on the class, so
    # a module-level class would let the first model's template leak into
    # every model loaded afterwards.
    class _GgufTemplateChatHandler(Llava15ChatHandler):
        DEFAULT_SYSTEM_MESSAGE = None

    handler_type: type[Llava15ChatHandler] = _GgufTemplateChatHandler

    embedded_template = read_chat_template(model_path)
    if embedded_template is None:
        log.info(
            "Vision chat handler: no GGUF-embedded chat template for %s; using upstream default",
            model_path.name,
        )
    else:
        handler_type.CHAT_FORMAT = adapt_gguf_template_for_mtmd(embedded_template)
        log.info(
            "Vision chat handler: using GGUF-embedded template (%d bytes) from %s",
            len(embedded_template),
            model_path.name,
        )

    return handler_type(str(mmproj_path), verbose=False)

134 

135 

def load_vision_llama(
    model_path: Path,
    mmproj_path: Path | None = None,
    *,
    abort_callback_override: Any = None,
) -> Any:
    """Create a vision-capable ``Llama`` wired to the GGUF-templated chat handler.

    When ``mmproj_path`` is omitted the projector is located with
    ``find_mmproj_for_model``. ``abort_callback_override`` lets pool workers
    bind a callback that reads the worker's shared ``mp.Value`` abort flag;
    without it the module-level ``abort_callback`` is used.
    """
    import os

    Llama = import_llama_cpp().Llama  # noqa: N806  # heavy native lib; keep import lazy

    install_llama_log_handler()
    mmproj_path = find_mmproj_for_model(model_path) if mmproj_path is None else mmproj_path
    vision_handler = build_vision_chat_handler(model_path, mmproj_path)

    # mtmd offloads the vision encoder to the GPU (use_gpu=True), but image
    # preprocessing, tokenization, and sampling still run on the CPU.
    # llama-cpp-python defaults n_threads to ~cpu_count()//2, which starves
    # those steps; Ollama runs full-core, so do the same here.
    thread_count = os.cpu_count() or 4
    context_size = _resolve_vision_n_ctx(model_path)

    llama_kwargs: dict[str, Any] = {
        "model_path": str(model_path),
        "chat_handler": vision_handler,
        "verbose": False,
        "n_gpu_layers": -1,
        "n_ctx": context_size,
        "n_threads": thread_count,
        "n_threads_batch": thread_count,
    }

    # mtmd only enables flash attention for the vision pass when the backing
    # Llama already has it ENABLED; left unset it defaults to DISABLED, which
    # makes the image-token prefill (thousands of tokens per page) far slower.
    # Mirror the chat loader and honor cfg.flash_attention (None/auto -> on).
    if cfg.flash_attention is not False:
        llama_kwargs["flash_attn"] = True
    if (gpu_index := cfg.main_gpu) is not None:
        llama_kwargs["main_gpu"] = gpu_index

    llama_kwargs["abort_callback"] = (
        abort_callback if abort_callback_override is None else abort_callback_override
    )

    loaded = suppress_native_stderr(Llama, **llama_kwargs)

    # Read the log fields defensively so a missing attribute shows up as "?"
    # in the log line instead of raising.
    model_meta = getattr(loaded, "metadata", {}) or {}
    ctx_getter = getattr(loaded, "n_ctx", None)
    effective_ctx = ctx_getter() if callable(ctx_getter) else "?"
    log.info(
        "Vision model loaded: model=%s mmproj=%s n_ctx=%s arch=%s",
        model_path.name,
        mmproj_path.name,
        effective_ctx,
        model_meta.get(GGUF_ARCH_KEY, "?"),
    )
    return loaded

196 

197 

def _resolve_vision_n_ctx(model_path: Path) -> int:
    """Choose the n_ctx for a vision load from the model's training context.

    The GGUF metadata is read and handed to ``train_ctx_from_meta``, which
    uses ``<arch>.context_length`` directly. The chat-tuned ``cfg.num_ctx``
    is deliberately not applied: a vision pass packs image-token embeddings
    plus the prompt (often hundreds to a few thousand tokens per page), and
    clamping to a small chat ctx truncates OCR output. An explicit value
    (rather than 0) keeps the OOM-retry path working, since
    ``_halve_ctx_for_retry`` cannot bisect from 0.
    """
    gguf_meta = None
    try:
        gguf_meta = read_gguf_metadata(model_path)
    except Exception:
        # Best effort: with unreadable metadata the fallback n_ctx applies.
        log.debug("read_gguf_metadata failed for vision %s", model_path, exc_info=True)
    return train_ctx_from_meta(
        gguf_meta,
        fallback=_VISION_FALLBACK_N_CTX,
        model_path=model_path,
    )