Coverage for src/lilbee/providers/worker/vision_worker.py: 100%
106 statements
1"""Long-lived vision-OCR worker subprocess body.
3Hosts both single-image OCR (``WireKind.VISION``, used by the live wiki /
4catalog flows) and multi-page PDF vision OCR (``WireKind.PDF_OCR``, used by
5the ingest pipeline). The PDF path streams one chunk per page so
6subscribers see incremental progress. Tesseract OCR has no shared model
7state and runs inline via ``asyncio.to_thread`` in the ingest caller,
8not through this worker.
9"""
from __future__ import annotations

import contextlib
import logging
import time
from pathlib import Path
from typing import Any

from lilbee.providers.worker.transport import PdfOcrRequest, RoleConfig, VisionRequest
from lilbee.providers.worker.transport_pipe import _serialize_exception
from lilbee.providers.worker.wire_kinds import WireKind
from lilbee.providers.worker.worker_runtime import Reply, WorkerLoopState, run_worker
from lilbee.vision import (
    OCR_PROMPT,
    PdfOcrChunk,
    build_vision_messages,
    pdf_page_count,
    rasterize_pdf,
)

log = logging.getLogger(__name__)


def _make_abort_callback(abort_flag: Any) -> Any:
    """Return a llama-cpp abort_callback bound to the shared mp.Value flag."""

    def _callback(_user_data: Any = None) -> bool:
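        # llama-cpp polls this callback during inference; a truthy shared
        # flag makes it return True, aborting the in-flight computation.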
        return bool(abort_flag.value)

    return _callback


class _VisionSession:
    """Lazy-loaded vision Llama, kept alive for the worker's lifetime."""

    def __init__(self, role_config: RoleConfig, abort_flag: Any) -> None:
        self._role_config = role_config
        self._abort_flag = abort_flag
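        # Loaded Llama plus the path it was loaded from; _ensure_loaded()
        # reloads only when the resolved target path changes.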
        self._llm: Any = None
        self._model_path: str = ""

    def ocr(self, *, png_bytes: bytes, prompt: str, model: str | None) -> str:
        """Run OCR on one image, loading the model on first use."""
        llm = self._ensure_loaded(model)
        messages = build_vision_messages(prompt or OCR_PROMPT, png_bytes)
        start = time.monotonic()
        response = llm.create_chat_completion(messages=messages, stream=False)
        text = _extract_vision_content(response)
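        # Token usage is logged best-effort; tolerate a missing or malformed
        # usage block rather than failing the OCR call over a log line.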
        usage = response.get("usage", {}) if isinstance(response, dict) else {}
        if not isinstance(usage, dict):
            usage = {}
        log.info(
            "vision_ocr wall=%.1fs prompt_tokens=%s completion_tokens=%s chars=%d",
            time.monotonic() - start,
            usage.get("prompt_tokens"),
            usage.get("completion_tokens"),
            len(text),
        )
        return text

    def _ensure_loaded(self, model_override: str | None) -> Any:
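        """Return the vision Llama, (re)loading it if the resolved model path changed."""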
        from lilbee.providers.llama_cpp.provider import resolve_model_path
        from lilbee.providers.mtmd_backend import load_vision_llama

        target_path = (
            resolve_model_path(model_override) if model_override else self._role_config.model_path
        )
        target_str = str(target_path)
        if self._llm is None or target_str != self._model_path:
            self._close_model()
            # The abort flag lives in shared memory (mp.Value), so the
            # callback bound here lets the parent's pool.cancel() reach
            # llama-cpp's vision inference loop in this subprocess.
            self._llm = load_vision_llama(
                target_path,
                abort_callback_override=_make_abort_callback(self._abort_flag),
            )
            self._model_path = target_str
        return self._llm

    def _close_model(self) -> None:
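        """Close the loaded Llama (suppressing close errors) and drop the reference."""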
        if self._llm is not None:
            with contextlib.suppress(Exception):
                self._llm.close()
            self._llm = None

    def close(self) -> None:
        """Release the loaded model. Idempotent."""
        self._close_model()


def _extract_vision_content(response: Any) -> str:
    """Pull the OCR text out of one llama-cpp vision response.

    Mirrors the chat path's defensive walk so a malformed response
    surfaces as a typed :class:`TypeError` we can serialize, instead of
    a raw :class:`KeyError` / :class:`IndexError` deep in the worker.
    """
    if not isinstance(response, dict):
        raise TypeError(f"vision response must be dict, got {type(response).__name__}")
    choices = response.get("choices")
    if not isinstance(choices, list) or not choices:
        raise TypeError("vision response missing 'choices' list")
    first = choices[0]
    if not isinstance(first, dict):
        raise TypeError(f"vision choices[0] must be dict, got {type(first).__name__}")
    message = first.get("message")
    if not isinstance(message, dict):
        raise TypeError("vision choices[0].message missing or not dict")
    content = message.get("content")
    return content if isinstance(content, str) else ""


def _handle_vision(reply: Reply, payload: Any, state: WorkerLoopState) -> None:
    """Run one vision OCR request and send the typed reply (or error)."""
    if not isinstance(payload, VisionRequest):
        try:
            raise TypeError(
                f"vision_ocr payload must be VisionRequest, got {type(payload).__name__}"
            )
        except TypeError as exc:
            reply.send(WireKind.ERROR, _serialize_exception(exc))
            return
    if not isinstance(payload.png_bytes, (bytes, bytearray)):
        try:
            raise TypeError("vision_ocr payload.png_bytes must be bytes")
        except TypeError as exc:
            reply.send(WireKind.ERROR, _serialize_exception(exc))
            return
    session: _VisionSession = state.session
    try:
        text = session.ocr(
            png_bytes=bytes(payload.png_bytes),
            prompt=payload.prompt,
            model=payload.model,
        )
    except Exception as exc:
        reply.send(WireKind.ERROR, _serialize_exception(exc))
        return
    reply.send(WireKind.RESULT, text)


def _handle_pdf_ocr(reply: Reply, payload: Any, state: WorkerLoopState) -> None:
    """Stream multi-page vision PDF OCR results, one chunk per page, then stream_end.

    Iterates rasterised PDF pages and OCRs each via the loaded vision
    Llama. ``payload.backend`` must be ``"vision"``; tesseract is run
    inline by the ingest caller, not here, because tesseract has no
    shared model state and pool routing buys it nothing.
    """
    if not isinstance(payload, PdfOcrRequest):
        try:
            raise TypeError(f"pdf_ocr payload must be PdfOcrRequest, got {type(payload).__name__}")
        except TypeError as exc:
            reply.send(WireKind.ERROR, _serialize_exception(exc))
            return
    # ``payload.backend`` is typed ``Literal["vision"]`` so any other
    # value is a type-system regression on the parent side; trust the
    # contract and validate only the payload shape above.
    session: _VisionSession = state.session
    try:
        path = Path(payload.path)
        total = pdf_page_count(path)
        model_override = payload.model or None
        for idx, png_bytes in rasterize_pdf(path):
            text = session.ocr(
                png_bytes=bytes(png_bytes),
                prompt=OCR_PROMPT,
                model=model_override,
            )
            # 1-based page index matches how the rest of lilbee numbers
            # pages (PageText, ExtractEvent, etc.). Total ships in every
            # chunk so consumers don't need a separate header frame.
            reply.send(WireKind.STREAM_CHUNK, PdfOcrChunk(page=idx + 1, total=total, text=text))
    except Exception as exc:
        reply.send(WireKind.ERROR, _serialize_exception(exc))
        return
    reply.send(WireKind.STREAM_END, None)


def vision_worker_main(
    data_conn: Any, health_conn: Any, abort_flag: Any, role_config: RoleConfig
) -> None:
    """Vision-OCR worker entrypoint: load llama-cpp lazily, serve until shutdown."""
    run_worker(
        data_conn,
        health_conn,
        abort_flag,
        role_config,
        session_factory=_VisionSession,
        kind_handlers={
            WireKind.VISION: _handle_vision,
            WireKind.PDF_OCR: _handle_pdf_ocr,
        },
    )
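
# Illustrative parent-side spawn, for orientation only -- the real pool code
# lives elsewhere, and RoleConfig construction is elided here rather than guessed:
#
#     import multiprocessing as mp
#
#     parent_data, child_data = mp.Pipe()
#     parent_health, child_health = mp.Pipe()
#     abort_flag = mp.Value("b", False)  # shared flag read by the abort callback
#     role_config = ...                  # built by the parent pool (not shown)
#     proc = mp.Process(
#         target=vision_worker_main,
#         args=(child_data, child_health, abort_flag, role_config),
#     )
#     proc.start()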

__all__ = ["vision_worker_main"]