Coverage for src/lilbee/providers/worker/vision_worker.py: 100%

106 statements  

coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

1"""Long-lived vision-OCR worker subprocess body. 

2 

3Hosts both single-image OCR (``WireKind.VISION``, used by the live wiki / 

4catalog flows) and multi-page PDF vision OCR (``WireKind.PDF_OCR``, used by 

5the ingest pipeline). The PDF path streams one chunk per page so 

6subscribers see incremental progress. Tesseract OCR has no shared model 

7state and runs inline via ``asyncio.to_thread`` in the ingest caller, 

8not through this worker. 

9""" 

10 

11from __future__ import annotations 

12 

13import contextlib 

14import logging 

15import time 

16from pathlib import Path 

17from typing import Any 

18 

19from lilbee.providers.worker.transport import PdfOcrRequest, RoleConfig, VisionRequest 

20from lilbee.providers.worker.transport_pipe import _serialize_exception 

21from lilbee.providers.worker.wire_kinds import WireKind 

22from lilbee.providers.worker.worker_runtime import Reply, WorkerLoopState, run_worker 

23from lilbee.vision import ( 

24 OCR_PROMPT, 

25 PdfOcrChunk, 

26 build_vision_messages, 

27 pdf_page_count, 

28 rasterize_pdf, 

29) 

30 

31log = logging.getLogger(__name__) 

32 

33 

34def _make_abort_callback(abort_flag: Any) -> Any: 

35 """Return a llama-cpp abort_callback bound to the shared mp.Value flag.""" 

36 

37 def _callback(_user_data: Any = None) -> bool: 

38 return bool(abort_flag.value) 

39 

40 return _callback 
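
# Illustrative sketch, not part of the worker logic: the flag is a shared
# ``multiprocessing.Value``-style object owned by the parent pool, so the
# callback flips as soon as the parent requests cancellation, e.g.:
#
#     import multiprocessing as mp
#
#     abort_flag = mp.Value("b", 0)      # shared byte flag, 0 = keep running
#     should_abort = _make_abort_callback(abort_flag)
#     should_abort()                     # -> False
#     abort_flag.value = 1               # parent-side cancel
#     should_abort()                     # -> True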



class _VisionSession:
    """Lazy-loaded vision Llama, kept alive for the worker's lifetime."""

    def __init__(self, role_config: RoleConfig, abort_flag: Any) -> None:
        self._role_config = role_config
        self._abort_flag = abort_flag
        self._llm: Any = None
        self._model_path: str = ""

    def ocr(self, *, png_bytes: bytes, prompt: str, model: str | None) -> str:
        """Run OCR on one image, loading the model on first use."""
        llm = self._ensure_loaded(model)
        messages = build_vision_messages(prompt or OCR_PROMPT, png_bytes)
        start = time.monotonic()
        response = llm.create_chat_completion(messages=messages, stream=False)
        text = _extract_vision_content(response)
        usage = response.get("usage", {}) if isinstance(response, dict) else {}
        if not isinstance(usage, dict):
            usage = {}
        log.info(
            "vision_ocr wall=%.1fs prompt_tokens=%s completion_tokens=%s chars=%d",
            time.monotonic() - start,
            usage.get("prompt_tokens"),
            usage.get("completion_tokens"),
            len(text),
        )
        return text

    def _ensure_loaded(self, model_override: str | None) -> Any:
        from lilbee.providers.llama_cpp.provider import resolve_model_path
        from lilbee.providers.mtmd_backend import load_vision_llama

        target_path = (
            resolve_model_path(model_override) if model_override else self._role_config.model_path
        )
        target_str = str(target_path)
        if self._llm is None or target_str != self._model_path:
            self._close_model()
            # The abort flag lives in shared memory (mp.Value), so the
            # callback bound here lets the parent's pool.cancel() reach
            # llama-cpp's vision inference loop in this subprocess.
            self._llm = load_vision_llama(
                target_path,
                abort_callback_override=_make_abort_callback(self._abort_flag),
            )
            self._model_path = target_str
        return self._llm

    def _close_model(self) -> None:
        if self._llm is not None:
            with contextlib.suppress(Exception):
                self._llm.close()
            self._llm = None

    def close(self) -> None:
        """Release the loaded model. Idempotent."""
        self._close_model()


def _extract_vision_content(response: Any) -> str:
    """Pull the OCR text out of one llama-cpp vision response.

    Mirrors the chat path's defensive walk so a malformed response
    surfaces as a typed :class:`TypeError` we can serialize, instead of
    a raw :class:`KeyError` / :class:`IndexError` deep in the worker.
    """
    if not isinstance(response, dict):
        raise TypeError(f"vision response must be dict, got {type(response).__name__}")
    choices = response.get("choices")
    if not isinstance(choices, list) or not choices:
        raise TypeError("vision response missing 'choices' list")
    first = choices[0]
    if not isinstance(first, dict):
        raise TypeError(f"vision choices[0] must be dict, got {type(first).__name__}")
    message = first.get("message")
    if not isinstance(message, dict):
        raise TypeError("vision choices[0].message missing or not dict")
    content = message.get("content")
    return content if isinstance(content, str) else ""
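
# Illustrative sketch: this walker expects llama-cpp's OpenAI-style
# chat-completion dict, roughly:
#
#     {
#         "choices": [
#             {"message": {"role": "assistant", "content": "<ocr text>"}},
#         ],
#         "usage": {"prompt_tokens": 1234, "completion_tokens": 256},
#     }
#
# Anything deviating from that shape raises one of the TypeErrors above,
# which the handlers below serialize back to the parent instead of letting
# the worker crash.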



def _handle_vision(reply: Reply, payload: Any, state: WorkerLoopState) -> None:
    """Run one vision OCR request and send the typed reply (or error)."""
    if not isinstance(payload, VisionRequest):
        try:
            raise TypeError(
                f"vision_ocr payload must be VisionRequest, got {type(payload).__name__}"
            )
        except TypeError as exc:
            reply.send(WireKind.ERROR, _serialize_exception(exc))
        return
    if not isinstance(payload.png_bytes, (bytes, bytearray)):
        try:
            raise TypeError("vision_ocr payload.png_bytes must be bytes")
        except TypeError as exc:
            reply.send(WireKind.ERROR, _serialize_exception(exc))
        return
    session: _VisionSession = state.session
    try:
        text = session.ocr(
            png_bytes=bytes(payload.png_bytes),
            prompt=payload.prompt,
            model=payload.model,
        )
    except Exception as exc:
        reply.send(WireKind.ERROR, _serialize_exception(exc))
        return
    reply.send(WireKind.RESULT, text)


def _handle_pdf_ocr(reply: Reply, payload: Any, state: WorkerLoopState) -> None:
    """Stream multi-page vision PDF OCR results, one chunk per page, then stream_end.

    Iterates rasterised PDF pages and OCRs each via the loaded vision
    Llama. ``payload.backend`` must be ``"vision"``; tesseract is run
    inline by the ingest caller, not here, because tesseract has no
    shared model state and pool routing buys it nothing.
    """
    if not isinstance(payload, PdfOcrRequest):
        try:
            raise TypeError(f"pdf_ocr payload must be PdfOcrRequest, got {type(payload).__name__}")
        except TypeError as exc:
            reply.send(WireKind.ERROR, _serialize_exception(exc))
        return
    # ``payload.backend`` is typed ``Literal["vision"]`` so any other
    # value is a type-system regression on the parent side; trust the
    # contract and validate only the payload shape above.
    session: _VisionSession = state.session
    try:
        path = Path(payload.path)
        total = pdf_page_count(path)
        model_override = payload.model or None
        for idx, png_bytes in rasterize_pdf(path):
            text = session.ocr(
                png_bytes=bytes(png_bytes),
                prompt=OCR_PROMPT,
                model=model_override,
            )
            # 1-based page index matches how the rest of lilbee numbers
            # pages (PageText, ExtractEvent, etc.). Total ships in every
            # chunk so consumers don't need a separate header frame.
            reply.send(WireKind.STREAM_CHUNK, PdfOcrChunk(page=idx + 1, total=total, text=text))
    except Exception as exc:
        reply.send(WireKind.ERROR, _serialize_exception(exc))
        return
    reply.send(WireKind.STREAM_END, None)
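
# Illustrative sketch of the resulting frame sequence for a three-page PDF:
#
#     STREAM_CHUNK  PdfOcrChunk(page=1, total=3, text="...")
#     STREAM_CHUNK  PdfOcrChunk(page=2, total=3, text="...")
#     STREAM_CHUNK  PdfOcrChunk(page=3, total=3, text="...")
#     STREAM_END    None
#
# or an ERROR frame (possibly after some chunks) if any page fails.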



def vision_worker_main(
    data_conn: Any, health_conn: Any, abort_flag: Any, role_config: RoleConfig
) -> None:
    """Vision-OCR worker entrypoint: load llama-cpp lazily, serve until shutdown."""
    run_worker(
        data_conn,
        health_conn,
        abort_flag,
        role_config,
        session_factory=_VisionSession,
        kind_handlers={
            WireKind.VISION: _handle_vision,
            WireKind.PDF_OCR: _handle_pdf_ocr,
        },
    )
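
# Illustrative sketch only: in practice the parent-side pool spawns this
# entrypoint. A minimal by-hand launch would look roughly like the below;
# ``RoleConfig(model_path=...)`` is an assumed constructor shape (this module
# only reads ``model_path``), and the real pool wires up more than this.
#
#     import multiprocessing as mp
#
#     ctx = mp.get_context("spawn")
#     parent_data, child_data = ctx.Pipe()
#     parent_health, child_health = ctx.Pipe()
#     abort_flag = ctx.Value("b", 0)
#     ctx.Process(
#         target=vision_worker_main,
#         args=(child_data, child_health, abort_flag,
#               RoleConfig(model_path="vision-model.gguf")),
#     ).start()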



__all__ = ["vision_worker_main"]