Coverage for src / lilbee / providers / llama_cpp / provider.py: 100%

444 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

1"""Llama.cpp provider: class, model loader, and path resolution.""" 

2 

3from __future__ import annotations 

4 

5import contextlib 

6import logging 

7import threading 

8from collections.abc import AsyncIterator, Callable 

9from dataclasses import dataclass 

10from pathlib import Path 

11from typing import Any, Literal, cast, overload 

12 

13from lilbee.app.services import get_services 

14from lilbee.catalog import is_rerank_ref 

15from lilbee.core.config import DEFAULT_NUM_CTX, cfg 

16from lilbee.core.config.enums import KV_CACHE_TYPE_BYTES, KvCacheType 

17from lilbee.providers.base import ClosableIterator, LLMProvider, ProviderError, filter_options 

18from lilbee.providers.llama_cpp.abort_signal import abort_callback, clear_abort 

19from lilbee.providers.llama_cpp.gguf_meta import ( 

20 find_mmproj_for_model, 

21 read_gguf_metadata, 

22 train_ctx_from_meta, 

23) 

24from lilbee.providers.llama_cpp.log_dispatch import ( 

25 import_llama_cpp, 

26 install_llama_log_handler, 

27 suppress_native_stderr, 

28) 

29from lilbee.providers.model_cache import ( 

30 LoaderMode, 

31 compute_dynamic_ctx, 

32 get_available_memory, 

33 kv_bytes_per_token, 

34) 

35from lilbee.providers.worker.chat_worker import chat_worker_main 

36from lilbee.providers.worker.embed_worker import embed_worker_main 

37from lilbee.providers.worker.pool import PoolRuntime, RoleAccessor 

38from lilbee.providers.worker.rerank_worker import rerank_worker_main 

39from lilbee.providers.worker.transport import ( 

40 ChatRequest, 

41 OcrBackend, 

42 PdfOcrRequest, 

43 RerankPayload, 

44 RoleConfig, 

45 VisionRequest, 

46 WorkerRole, 

47) 

48from lilbee.providers.worker.transport_pipe import WorkerCrashError, WorkerError 

49from lilbee.providers.worker.vision_worker import vision_worker_main 

50from lilbee.providers.worker.wire_kinds import WireKind 

51from lilbee.runtime.progress import EventType, ExtractEvent 

52from lilbee.vision import PageText, PdfOcrChunk, pdf_page_count 

53 

log = logging.getLogger(__name__)

# Vision OCR sentinel used when no per-call timeout and no ``cfg.ocr_timeout``
# is set. 24h is effectively "no cap" for the round-trip wait loop.
_VISION_NO_CAP_TIMEOUT_S = 86_400.0

_LLAMA_CONTEXT_PATCH_LOCK = threading.Lock()
"""Serialises overlapping ``_llama_n_seq_max`` callers inside one process.

The shim mutates ``llama_cpp.internals.LlamaContext.__init__`` globally
while the with-block is open. Worker subprocesses each load one model
serially today, but the lock keeps the contract safe if a future caller
loads two models concurrently.
"""

68 

69 

@contextlib.contextmanager
def _llama_n_seq_max(n_seq_max: int) -> Any:
    """Force ``context_params.n_seq_max`` on the next ``LlamaContext`` built.

    Workaround for llama-cpp-python upstream issue #2051 (``n_seq_max``
    not exposed as a Llama kwarg). See ``docs/architecture.md`` for the
    full rationale and the upstream-fix removal hint.
    """
    from llama_cpp import internals

    with _LLAMA_CONTEXT_PATCH_LOCK:
        saved_init = internals.LlamaContext.__init__

        def shimmed(self: Any, *, model: Any, params: Any, verbose: bool) -> None:
            # Inject the requested sequence cap before delegating to the
            # real constructor.
            params.n_seq_max = n_seq_max
            saved_init(self, model=model, params=params, verbose=verbose)

        internals.LlamaContext.__init__ = shimmed  # type: ignore[method-assign,assignment]
        try:
            yield
        finally:
            # Always restore, even if the caller's load raised.
            internals.LlamaContext.__init__ = saved_init  # type: ignore[method-assign]

92 

93 

# Cap on tokens drained during ``_PoolChatStreamIterator.close()`` after a
# mid-stream cancel. A runaway model (Qwen3-0.6B stuck in a never-closing
# ``<think>`` loop) would otherwise block close() indefinitely.
_CHAT_STREAM_DRAIN_CAP = 1024

# Chat-load OOM retry knobs. The OOM wrapper halves ``n_ctx`` (rounded down to
# the next ``_CTX_QUANTUM`` multiple) up to ``_MAX_OOM_RETRIES`` times before
# raising. ``_CTX_FLOOR`` is the smallest ``n_ctx`` we'll attempt.
_MAX_OOM_RETRIES = 2
_CTX_QUANTUM = 256
_CTX_FLOOR = 512

# Sentinel passed to ``llama-cpp-python`` for "offload all layers".
_N_GPU_LAYERS_AUTO = -1


# Settings baked into Llama() at load time, or whose change picks a
# different model file. Sampling params are read per-call and excluded.
LOAD_AFFECTING_KEYS = frozenset(
    {
        "num_ctx",
        "chat_model",
        "embedding_model",
        "vision_model",
        "reranker_model",
    }
)

# Subset of LOAD_AFFECTING_KEYS whose change is observed by the worker on the
# next per-call ``request.model`` (chat_worker / vision_worker check the path
# in ``_ensure_loaded`` and reload in place). For these, the parent does not
# need to release the pool role; the next call swaps the model inside the live
# worker, saving the 1-3 s spawn cost.
PER_CALL_RELOADABLE_KEYS = frozenset({"chat_model", "vision_model"})

128 

129 

class LlamaCppProvider(LLMProvider):
    """Provider backed by llama-cpp-python for local GGUF model inference."""

    def __init__(self) -> None:
        # Guards first-time role registration (``_get_pool_accessor``) and
        # role release (``_release_pool_roles``) against concurrent callers.
        self._pool_lock = threading.Lock()
        self._registered_roles: set[WorkerRole] = set()

    @staticmethod
    def _worker_error_message(role_label: str, exc: WorkerError) -> str:
        """Render a user-facing message that names the role and points at the log.

        ``WorkerCrashError`` already embeds the log path in its message; for
        plain ``WorkerError`` (the worker reported an exception or returned
        a malformed reply) the surfaced text is the worker's exception
        repr so the user sees enough to file a bug report.
        """
        detail = str(exc)
        if detail.endswith("."):  # the wrapper supplies its own sentence-final period
            detail = detail[:-1]
        if isinstance(exc, WorkerCrashError):
            return f"{role_label} worker exited unexpectedly. {detail}. Please try again."
        return f"{role_label} worker reported an error: {detail}. Please try again."

    def _pool_runtime(self) -> PoolRuntime:
        """Return the Services-owned :class:`PoolRuntime`, starting it lazily."""
        runtime = get_services().pool_runtime
        runtime.start()
        return runtime

    def _get_pool_accessor(
        self,
        role: WorkerRole,
        worker_main: Any,
        config_factory: Callable[[], RoleConfig],
    ) -> RoleAccessor:
        """Register *role* on the Services pool the first time it is used.

        Subsequent calls return the same accessor without touching the
        pool state. Registration is gated by ``self._pool_lock`` so two
        concurrent first-callers do not race to register the role twice.
        """
        pool = get_services().worker_pool
        with self._pool_lock:
            if role not in self._registered_roles:
                pool.register(role, worker_main, config_factory)
                self._registered_roles.add(role)
        return pool.accessor(role)

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed texts via the persistent pool worker.

        Worker crashes and timeouts surface as :class:`ProviderError`;
        the pool respawns the embed role lazily on the next call.
        """
        accessor = self._get_pool_accessor(
            WorkerRole.EMBED, embed_worker_main, _make_role_config_factory(WorkerRole.EMBED)
        )
        runtime = self._pool_runtime()
        try:
            result = runtime.run_sync(
                accessor.call(WireKind.EMBED, texts, timeout=cfg.worker_pool_call_timeout_s),
                timeout=cfg.worker_pool_call_timeout_s,
            )
            if not isinstance(result, list):
                # Raised inside the try on purpose: the WorkerError ->
                # ProviderError mapping below then covers protocol
                # violations the same way as worker-reported failures.
                raise WorkerError(
                    "ProtocolError",
                    f"Pool embed returned {type(result).__name__}, expected list[list[float]].",
                    "",
                )
        except WorkerError as exc:
            raise ProviderError(
                self._worker_error_message("Embedding", exc),
                provider="llama-cpp",
            ) from exc
        except TimeoutError as exc:
            raise ProviderError(
                "Embedding worker timed out. Please try again.",
                provider="llama-cpp",
            ) from exc
        return result

    def rerank(self, query: str, candidates: list[str]) -> list[float]:
        """Score *candidates* by relevance to *query* via the pool worker."""
        # Empty candidate list short-circuits without spawning a worker.
        if not candidates:
            return []
        accessor = self._get_pool_accessor(
            WorkerRole.RERANK, rerank_worker_main, _make_role_config_factory(WorkerRole.RERANK)
        )
        runtime = self._pool_runtime()
        try:
            result = runtime.run_sync(
                accessor.call(
                    WireKind.RERANK,
                    RerankPayload(query=query, candidates=candidates),
                    timeout=cfg.worker_pool_call_timeout_s,
                ),
                timeout=cfg.worker_pool_call_timeout_s,
            )
            if not isinstance(result, list):
                # See ``embed``: raised inside the try so the shared
                # WorkerError -> ProviderError mapping applies.
                raise WorkerError(
                    "ProtocolError",
                    f"Pool rerank returned {type(result).__name__}, expected list[float].",
                    "",
                )
        except WorkerError as exc:
            raise ProviderError(
                self._worker_error_message("Rerank", exc),
                provider="llama-cpp",
            ) from exc
        except TimeoutError as exc:
            raise ProviderError(
                "Rerank worker timed out. Please try again.",
                provider="llama-cpp",
            ) from exc
        return result

    def supports_rerank(self) -> bool:
        """llama-cpp can rerank iff llama-cpp-python exposes the rank pooling type."""
        return _llama_cpp_has_rank_pooling()

    def vision_ocr(
        self, png_bytes: bytes, model: str, prompt: str = "", *, timeout: float | None = None
    ) -> str:
        """Run vision OCR via the persistent pool worker."""
        accessor = self._get_pool_accessor(
            WorkerRole.VISION, vision_worker_main, _make_role_config_factory(WorkerRole.VISION)
        )
        runtime = self._pool_runtime()
        budget = self._vision_call_budget(timeout)
        # ``model or None`` maps an empty string to "use the configured model".
        request = VisionRequest(png_bytes=png_bytes, prompt=prompt, model=model or None)
        try:
            result = runtime.run_sync(
                accessor.call(WireKind.VISION, request, timeout=budget),
                timeout=budget,
            )
            if not isinstance(result, str):
                # See ``embed``: raised inside the try so the shared
                # WorkerError -> ProviderError mapping applies.
                raise WorkerError(
                    "ProtocolError",
                    f"Pool vision_ocr returned {type(result).__name__}, expected str.",
                    "",
                )
        except WorkerError as exc:
            raise ProviderError(
                self._worker_error_message("Vision", exc),
                provider="llama-cpp",
            ) from exc
        except TimeoutError as exc:
            raise ProviderError(
                "Vision worker timed out. Please try again.",
                provider="llama-cpp",
            ) from exc
        return result

    @staticmethod
    def _vision_call_budget(timeout: float | None) -> float:
        """Wall-clock budget for one vision_ocr call (per-call > cfg.ocr_timeout > no cap)."""
        effective = timeout if timeout is not None else cfg.ocr_timeout
        return float(effective) if effective and effective > 0 else _VISION_NO_CAP_TIMEOUT_S

    def pdf_ocr(
        self,
        path: Path,
        *,
        backend: OcrBackend,
        model: str = "",
        per_page_timeout_s: float | None = None,
        quiet: bool = True,
        on_progress: Callable[..., None] | None = None,
    ) -> list[PageText]:
        """Run multi-page vision PDF OCR via the persistent vision worker.

        ``per_page_timeout_s`` is *per page*. The total wall-clock cap on
        the streamed drain is ``pages * per_page + cfg.vision_load_budget_s``
        (load grace), so a 100-page scan with a 60 s per-page budget gets
        ~6000 s + load, not 60 s for the whole document.
        """
        accessor = self._get_pool_accessor(
            WorkerRole.VISION, vision_worker_main, _make_role_config_factory(WorkerRole.VISION)
        )
        runtime = self._pool_runtime()
        budget = self._pdf_drain_budget(path, per_page_timeout_s)
        del quiet  # accepted for Protocol parity; worker has no Rich progress to suppress.
        request = PdfOcrRequest(
            path=str(path),
            backend=backend,
            model=model,
        )
        progress = on_progress

        async def _drain() -> list[PageText]:
            # Collect streamed per-page chunks, forwarding progress events
            # to the caller as each page completes.
            pages: list[PageText] = []
            stream = cast(AsyncIterator[Any], accessor.stream(WireKind.PDF_OCR, request))
            async for frame in stream:
                if not isinstance(frame, PdfOcrChunk):
                    raise ProviderError(
                        f"PDF OCR worker streamed unexpected frame type {type(frame).__name__}.",
                        provider="llama-cpp",
                    )
                pages.append(PageText(frame.page, frame.text))
                if progress is not None:
                    progress(
                        EventType.EXTRACT,
                        ExtractEvent(file=path.name, page=frame.page, total_pages=frame.total),
                    )
            return pages

        try:
            return runtime.run_sync(_drain(), timeout=budget)
        except WorkerError as exc:
            raise ProviderError(
                self._worker_error_message("PDF OCR", exc),
                provider="llama-cpp",
            ) from exc
        except TimeoutError as exc:
            raise ProviderError(
                "PDF OCR worker timed out. Please try again.",
                provider="llama-cpp",
            ) from exc

    def _pdf_drain_budget(self, path: Path, per_page_timeout_s: float | None) -> float:
        """Total drain timeout = page_count * per_page + vision_load_budget_s."""
        if not per_page_timeout_s or per_page_timeout_s <= 0:
            return _VISION_NO_CAP_TIMEOUT_S
        try:
            pages = pdf_page_count(path)
        except Exception:
            # If we can't probe pages upfront the worker will still try,
            # but we lose the precise budget; fall back to no-cap so the
            # parent doesn't kill a valid run on a probe failure.
            return _VISION_NO_CAP_TIMEOUT_S
        return float(pages) * per_page_timeout_s + cfg.vision_load_budget_s

    @overload
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: Literal[False] = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str: ...

    @overload
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: Literal[True],
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> ClosableIterator[str]: ...

    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: bool = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str | ClosableIterator[str]:
        """Chat completion via the persistent pool worker.

        Streaming returns a :class:`ClosableIterator[str]` whose
        ``close()`` flips the worker's abort flag so in-flight generation
        drains cleanly. Non-streaming returns the assembled assistant text.
        """
        accessor = self._get_pool_accessor(
            WorkerRole.CHAT, chat_worker_main, _make_role_config_factory(WorkerRole.CHAT)
        )
        runtime = self._pool_runtime()
        accessor.clear_abort()  # honor mid-stream cancels from the previous turn
        request = ChatRequest(
            messages=messages,
            stream=stream,
            options=self._chat_kwargs_from_options(options) or None,
            model=model,
        )
        if stream:
            return _PoolChatStreamIterator(
                runtime=runtime,
                accessor=accessor,
                async_iter=accessor.stream(WireKind.CHAT, request),
            )
        try:
            result = runtime.run_sync(
                accessor.call(WireKind.CHAT, request, timeout=cfg.worker_pool_call_timeout_s),
                timeout=cfg.worker_pool_call_timeout_s,
            )
            if not isinstance(result, str):
                # See ``embed``: raised inside the try so the shared
                # WorkerError -> ProviderError mapping applies.
                raise WorkerError(
                    "ProtocolError",
                    f"Pool chat returned {type(result).__name__}, expected str.",
                    "",
                )
        except WorkerError as exc:
            raise ProviderError(
                self._worker_error_message("Chat", exc),
                provider="llama-cpp",
            ) from exc
        except TimeoutError as exc:
            raise ProviderError(
                "Chat worker timed out. Please try again.",
                provider="llama-cpp",
            ) from exc
        return result

    @staticmethod
    def _chat_kwargs_from_options(options: dict[str, Any] | None) -> dict[str, Any]:
        """Translate user-facing options into llama-cpp create_chat_completion kwargs."""
        if not options:
            return {}
        filtered = filter_options(options)
        if "num_predict" in filtered:
            filtered["max_tokens"] = filtered.pop("num_predict")
        filtered.pop("num_ctx", None)  # model-load param, not per-call
        return filtered

    def list_models(self) -> list[str]:
        """List installed models from registry."""
        registry = get_services().registry
        return sorted(m.ref for m in registry.list_installed())

    def list_chat_models(self, provider: str) -> list[str]:
        """llama-cpp has no frontier-provider catalog; always ``[]``."""
        return []

    def pull_model(self, model: str, *, on_progress: Callable[..., Any] | None = None) -> None:
        """Not supported directly: ``lilbee.catalog`` handles downloads."""
        raise NotImplementedError(
            f"llama-cpp provider cannot pull model {model!r}. "
            "Download GGUF files manually or use the catalog."
        )

    def show_model(self, model: str) -> dict[str, Any] | None:
        """Return model metadata from GGUF headers."""
        try:
            path = resolve_model_path(model)
        except ProviderError:
            return None
        return read_gguf_metadata(path)

    def get_capabilities(self, model: str) -> list[str]:
        """Detect capabilities from local GGUF files.

        Rerank models return ``["rerank"]``; cross-encoder GGUFs cannot
        generate text. Other models report ``"completion"``, plus
        ``"vision"`` when an mmproj sidecar is present.
        """
        if _is_rerank_model(model):
            return ["rerank"]
        caps: list[str] = ["completion"]
        try:
            path = resolve_model_path(model)
        except ProviderError:
            log.debug("resolve_model_path failed for %s", model, exc_info=True)
            return caps
        try:
            # find_mmproj_for_model raises ProviderError when no sidecar
            # exists; success means the model can do vision.
            find_mmproj_for_model(path)
            caps.append("vision")
        except ProviderError:
            log.debug("no mmproj for %s", model, exc_info=True)
        return caps

    def warm_up_pool(self) -> None:
        """Register roles for every configured model. Idempotent.

        Called by ``Services`` when ``cfg.worker_pool_eager_start`` is on so
        ``WorkerPool.start_eager()`` has roles to spawn. Roles whose model is
        unset are skipped; this lets a setup with only ``chat_model`` +
        ``embedding_model`` configured eager-start exactly those two and not
        pay rerank or vision spawn cost.
        """
        for role, _spec in _ROLE_SPECS.items():
            if not _is_role_configured(role):
                continue
            entrypoint = _ROLE_ENTRYPOINTS[role]
            self._get_pool_accessor(role, entrypoint, _make_role_config_factory(role))

    def shutdown(self) -> None:
        """Drop pool registrations so a follow-up provider can re-register cleanly."""
        self._release_pool_roles()

    def _release_pool_roles(self) -> None:
        """Drop our registrations on the Services pool so the next call respawns.

        Safe even when Services has not yet been built (early shutdown
        on import-time failure). Holds ``self._pool_lock`` so a concurrent
        ``_get_pool_accessor`` does not race the role removal.
        """
        with self._pool_lock:
            roles = tuple(self._registered_roles)
            self._registered_roles.clear()
        if not roles:
            return
        from lilbee.providers.worker.pool import PoolShutdownError

        services = get_services()
        runtime = services.pool_runtime
        for role in roles:
            try:
                runtime.run_sync(services.worker_pool.release(role), timeout=10.0)
            except PoolShutdownError:
                # Pool already shut down (atexit ordering during a CLI exit
                # tears down the pool runtime before this provider). Nothing
                # to release; silent no-op.
                pass
            except (TimeoutError, RuntimeError, OSError) as exc:
                log.warning("Pool release of role=%s raised %s", role, exc)

    def invalidate_load_cache(self, model_path: Path | None = None) -> None:
        """Drop the pool's per-role workers so the next call respawns with current settings.

        The ``model_path`` argument is accepted for protocol parity with
        other providers but does not narrow the scope: workers reload all
        their roles on respawn anyway.
        """
        del model_path
        self._release_pool_roles()

548 

549 

class _PoolChatStreamIterator:
    """Sync facade over an async chat-stream iterator from the worker pool.

    Each ``__next__`` submits one ``__anext__`` to the pool's runtime
    loop and blocks for the result. ``close()`` flips the worker's abort
    flag so any in-flight generation stops at the next token-tick;
    in-flight chunks already in the pipe still drain.
    """

    def __init__(
        self,
        *,
        runtime: PoolRuntime,
        accessor: RoleAccessor,
        async_iter: Any,
    ) -> None:
        self._loop_runtime = runtime
        self._role_accessor = accessor
        self._source = async_iter
        self._done = False

    def __iter__(self) -> _PoolChatStreamIterator:
        return self

    def __next__(self) -> str:
        if self._done:
            raise StopIteration
        try:
            piece: str = self._loop_runtime.run_sync(
                self._source.__anext__(),
                timeout=cfg.worker_pool_call_timeout_s,
            )
        except StopAsyncIteration:
            self._done = True
            raise StopIteration from None
        except WorkerError as exc:
            # Mid-stream worker crashes propagate as ProviderError so the
            # streaming path matches the non-streaming contract.
            self._done = True
            raise ProviderError(
                LlamaCppProvider._worker_error_message("Chat", exc),
                provider="llama-cpp",
            ) from exc
        except TimeoutError as exc:
            self._done = True
            raise ProviderError(
                "Chat worker timed out mid-stream. Please try again.",
                provider="llama-cpp",
            ) from exc
        return piece

    def close(self) -> None:
        """Cancel mid-stream and drain remaining tokens from the pipe.

        Drain is bounded by ``_CHAT_STREAM_DRAIN_CAP`` so a stuck
        worker cannot block close() indefinitely; once the cap fires we
        accept the partial-state for not hanging the UI.
        """
        if self._done:
            return
        self._role_accessor.cancel()
        for _ in range(_CHAT_STREAM_DRAIN_CAP):
            try:
                next(self)
            except StopIteration:
                break
            except Exception:
                break
        self._role_accessor.clear_abort()
        self._done = True

    def __del__(self) -> None:  # pragma: no cover
        with contextlib.suppress(Exception):
            self.close()

626 

627 

@dataclass(frozen=True)
class _RoleSpec:
    """Per-role recipe for building a :class:`RoleConfig` from cfg."""

    # Name of the ``cfg`` attribute holding the role's model reference.
    cfg_attr: str
    # Loader mode forwarded into RoleConfig.mode (a LoaderMode value, or
    # the documentation-only "vision" string — see _ROLE_SPECS).
    mode: str

635 

# Map of worker role -> recipe for resolving its model and loader mode.
_ROLE_SPECS: dict[WorkerRole, _RoleSpec] = {
    WorkerRole.EMBED: _RoleSpec(cfg_attr="embedding_model", mode=LoaderMode.EMBED),
    WorkerRole.RERANK: _RoleSpec(cfg_attr="reranker_model", mode=LoaderMode.RERANK),
    WorkerRole.CHAT: _RoleSpec(cfg_attr="chat_model", mode=LoaderMode.CHAT),
    # Vision uses a custom mtmd loader (not load_llama); the mode hint is
    # documentation only, the vision worker calls load_vision_llama directly.
    WorkerRole.VISION: _RoleSpec(cfg_attr="vision_model", mode="vision"),
}


# Map of worker role -> subprocess entrypoint, used by warm_up_pool.
_ROLE_ENTRYPOINTS: dict[WorkerRole, Callable[..., None]] = {
    WorkerRole.EMBED: embed_worker_main,
    WorkerRole.RERANK: rerank_worker_main,
    WorkerRole.CHAT: chat_worker_main,
    WorkerRole.VISION: vision_worker_main,
}

652 

653 

def _is_role_configured(role: WorkerRole) -> bool:
    """True iff the cfg attribute for *role* holds a non-empty model name."""
    attr_name = _ROLE_SPECS[role].cfg_attr
    return bool(getattr(cfg, attr_name))

657 

658 

def _make_role_config_factory(role: WorkerRole) -> Callable[[], RoleConfig]:
    """Return a factory that resolves the role's configured model at spawn time.

    The pool calls the factory on every spawn (lazy or restart) so model
    swaps in cfg propagate without an explicit invalidation call.
    """
    spec = _ROLE_SPECS[role]

    def _factory() -> RoleConfig:
        # Read the model name fresh on each spawn, not at registration time.
        model_name = getattr(cfg, spec.cfg_attr)
        if not model_name:
            raise ProviderError(
                f"No {role} model configured. Set cfg.{spec.cfg_attr} first.",
                provider="llama-cpp",
            )
        resolved = resolve_model_path(model_name)
        return RoleConfig(role=role, model_path=resolved, mode=spec.mode)

    return _factory

681 

682 

def resolve_model_path(model: str) -> Path:
    """Resolve a model name to a .gguf file path.

    Resolution order:

    1. Registry (canonical source for installed models)
    2. Absolute path (if it points to an existing *file*)

    Raises:
        ProviderError: when the name is neither registered nor an
            absolute path to an existing file.
    """
    registry = get_services().registry
    try:
        return registry.resolve(model)
    except (KeyError, ValueError):
        pass  # not registered; fall through to the absolute-path check

    # Absolute path to a .gguf file. ``is_file()`` (not ``exists()``) so an
    # existing *directory* is rejected here with a clear "not found" message
    # instead of being handed to the model loader and failing opaquely later.
    candidate = Path(model)
    if candidate.is_absolute():
        if candidate.is_file():
            return candidate
        raise ProviderError(f"Model file not found: {model}", provider="llama-cpp")

    raise ProviderError(
        f"Model {model!r} not found in registry. "
        f"Install it via the catalog or 'lilbee model pull'.",
        provider="llama-cpp",
    )

707 

708 

def _llama_cpp_has_rank_pooling() -> bool:
    """Return True iff the installed llama-cpp-python exposes ``LLAMA_POOLING_TYPE_RANK``."""
    # supports_rerank() can be called before any model load (feature detection
    # for status / catalog UIs), so route through import_llama_cpp() first to
    # surface the libvulkan hint rather than a raw OSError on bare Linux. A
    # genuinely-missing llama_cpp package surfaces as ImportError and means
    # "no rerank support"; a libvulkan-flavored OSError is a real install
    # error and must propagate as ProviderError to the caller.
    try:
        import_llama_cpp()
        from llama_cpp import LLAMA_POOLING_TYPE_RANK  # noqa: F401
    except ImportError:
        return False
    else:
        return True

723 

724 

def load_llama(
    model_path: Path,
    *,
    mode: LoaderMode,
    abort_callback_override: Any = None,
) -> Any:
    """Load a llama_cpp.Llama in chat, embed, or rerank mode.

    ``abort_callback_override`` lets pool workers bind a callback that
    reads the worker's shared ``mp.Value`` abort flag.
    """
    Llama = import_llama_cpp().Llama  # noqa: N806

    install_llama_log_handler()
    embedding = mode in (LoaderMode.EMBED, LoaderMode.RERANK)
    kwargs: dict[str, Any] = {
        "model_path": str(model_path),
        "embedding": embedding,
        "verbose": False,
        "n_gpu_layers": _resolve_n_gpu_layers(embedding=embedding),
    }
    if cfg.main_gpu is not None:
        kwargs["main_gpu"] = cfg.main_gpu

    if embedding:
        # Embedding/rerank uses the model's training context unconditionally.
        # cfg.num_ctx is a chat-tuned setting; propagating it here used to
        # clamp the rerank model below what a query+candidate pair needs and
        # produced "llama_decode returned 1" on every other query when the
        # user picked a small chat ctx for a low-RAM box. The explicit
        # ``embed_train_ctx`` value (instead of ``0`` for "use model
        # default") keeps the OOM-retry path working: ``_halve_ctx_for_retry``
        # cannot bisect from 0.
        embed_meta = _safe_read_gguf_metadata(model_path)
        embed_train_ctx = train_ctx_from_meta(
            embed_meta, fallback=_EMBED_FALLBACK_CTX, model_path=model_path
        )
        kwargs["n_ctx"] = embed_train_ctx
    elif cfg.num_ctx is not None:
        # Explicit user-configured context wins over dynamic sizing.
        kwargs["n_ctx"] = cfg.num_ctx
    else:
        meta = _safe_read_gguf_metadata(model_path)
        kwargs["n_ctx"] = _resolve_chat_ctx(model_path, meta)
        log.info(
            "Chat n_ctx=%d for %s (dynamic, training_ctx=%s)",
            kwargs["n_ctx"],
            model_path.name,
            (meta or {}).get("context_length", "unknown"),
        )

    if embedding:
        # llama-cpp-python defaults n_batch = min(n_ctx, 512), silently
        # truncating embeddings to 512 tokens. Set n_batch = n_ctx so each
        # text can use the model's full context window.
        kwargs["n_batch"] = kwargs["n_ctx"]
        kwargs["n_ubatch"] = kwargs["n_ctx"]

    if mode == LoaderMode.RERANK:
        from llama_cpp import LLAMA_POOLING_TYPE_RANK

        kwargs["pooling_type"] = LLAMA_POOLING_TYPE_RANK

    if not embedding:
        # Flash attention and quantized KV cache are chat-only tunables.
        _apply_flash_attention(kwargs)
        _apply_kv_cache_type(kwargs)

    if abort_callback_override is not None:
        kwargs["abort_callback"] = abort_callback_override

    if embedding:
        from lilbee.providers.llama_cpp.batching import EMBED_N_SEQ_MAX

        with _llama_n_seq_max(EMBED_N_SEQ_MAX):
            return _construct_llama(Llama, model_path, kwargs)
    return _construct_llama(Llama, model_path, kwargs)

800 

801 

def _safe_read_gguf_metadata(model_path: Path) -> dict[str, str] | None:
    """Best-effort GGUF metadata read, returning None on any failure."""
    try:
        meta = read_gguf_metadata(model_path)
    except Exception:
        # Metadata is advisory here; log at debug and let callers fall back.
        log.debug("read_gguf_metadata failed for %s", model_path, exc_info=True)
        return None
    return meta

809 

810 

# Fallback used when an embedding GGUF reports zero, negative, or
# unparseable ``context_length`` in its metadata header. Some published
# nomic-embed and Qwen3 GGUFs in the wild report ``0`` (the b473 QA dump
# logged ``n_ctx_seq (512) > n_ctx_train (0)``). 2048 is the documented
# training-context for the smallest featured embedder that uses it
# (Google's EmbeddingGemma-300m, see
# https://huggingface.co/google/embeddinggemma-300m), and llama.cpp
# tolerates n_ctx > n_ctx_train with a warning, so the larger nomic
# embedder still loads cleanly under the same fallback.
_EMBED_FALLBACK_CTX = 2048

821 

822 

def _resolve_chat_ctx(model_path: Path, meta: dict[str, str] | None) -> int:
    """Pick the largest 256-multiple n_ctx that fits in available memory."""
    training_ctx = train_ctx_from_meta(meta, fallback=DEFAULT_NUM_CTX, model_path=model_path)
    ceiling = cfg.num_ctx_max

    try:
        # Size the context from the model file's weight bytes plus the KV
        # cache cost per token, bounded by the configured ceiling.
        return compute_dynamic_ctx(
            model_bytes=model_path.stat().st_size,
            available_bytes=get_available_memory(cfg.gpu_memory_fraction),
            training_ctx=training_ctx,
            kv_bytes_per_tok=kv_bytes_per_token(meta, _kv_elem_bytes_for_cfg()),
            ceiling=ceiling,
        )
    except (OSError, ValueError):
        log.debug("dynamic ctx sizing failed for %s, using static cap", model_path, exc_info=True)
        return min(training_ctx, DEFAULT_NUM_CTX)

842 

843 

def _kv_elem_bytes_for_cfg() -> int:
    """Bytes per KV element implied by the configured cache type."""
    cache_type = cfg.kv_cache_type
    return KV_CACHE_TYPE_BYTES[cache_type]

847 

848 

def _resolve_n_gpu_layers(*, embedding: bool) -> int:
    """Resolve ``cfg.n_gpu_layers`` (None=all) to llama-cpp's offload integer."""
    configured = cfg.n_gpu_layers
    # Embedding/rerank models always offload fully; chat honors the setting.
    if embedding or configured is None:
        return _N_GPU_LAYERS_AUTO
    return configured

854 

855 

def _apply_flash_attention(kwargs: dict[str, Any]) -> None:
    """Set ``flash_attn`` per ``cfg.flash_attention`` (None=auto, True/False=force)."""
    # None (auto) and True both pass flash_attn=True; the construct loop
    # drops it on TypeError if llama-cpp-python doesn't support it. An
    # explicit False leaves the kwarg unset.
    if cfg.flash_attention is not False:
        kwargs["flash_attn"] = True

863 

864 

def _apply_kv_cache_type(kwargs: dict[str, Any]) -> None:
    """Map ``cfg.kv_cache_type`` to llama-cpp-python ``type_k`` / ``type_v``."""
    cache_type = cfg.kv_cache_type
    # F16 is llama.cpp's default; nothing to set.
    if cache_type is KvCacheType.F16:
        return
    mapping = _ggml_type_map()
    if mapping is None:
        log.debug("llama_cpp internal types unavailable; skipping KV quant")
        return
    resolved = mapping.get(cache_type)
    if resolved is None:  # pragma: no cover -- defensive against new enum values
        return
    kwargs["type_k"] = resolved
    kwargs["type_v"] = resolved

878 

879 

def _ggml_type_map() -> dict[KvCacheType, Any] | None:
    """Resolve llama-cpp-python's GGML_TYPE_* constants, or None on older builds."""
    try:
        from llama_cpp import llama_cpp as _llc
    except Exception:  # pragma: no cover -- only fires on llama-cpp-python without _llc
        return None
    # Missing constants map to None so callers can tell "type unsupported"
    # apart from "module unavailable".
    wanted = {
        KvCacheType.F32: "GGML_TYPE_F32",
        KvCacheType.F16: "GGML_TYPE_F16",
        KvCacheType.Q8_0: "GGML_TYPE_Q8_0",
        KvCacheType.Q4_0: "GGML_TYPE_Q4_0",
    }
    return {cache_type: getattr(_llc, attr, None) for cache_type, attr in wanted.items()}

892 

893 

def _construct_llama(llama_cls: Any, model_path: Path, kwargs: dict[str, Any]) -> Any:
    """Call ``llama_cls(**kwargs)`` with FA fallback and OOM-retry-with-halved-ctx.

    Each loop iteration either returns the loaded model, raises (failure
    or unrelated TypeError), or continues with halved n_ctx; the loop is
    therefore structurally exhaustive and never falls through.

    Args:
        llama_cls: The llama-cpp-python ``Llama`` class (or compatible callable).
        model_path: Path to the model file; used only for error diagnostics.
        kwargs: Constructor kwargs; mutated in place (``flash_attn`` may be
            dropped, ``n_ctx``/``n_batch``/``n_ubatch`` may be halved, and
            ``abort_callback`` is injected when absent).

    Returns:
        The constructed model instance.

    Raises:
        ValueError: When loading still fails after all OOM retries (possibly
            wrapped with a diagnostic by ``_raise_load_error``).
        TypeError: When the constructor rejects a kwarg other than
            ``flash_attn``.
    """
    # Fresh abort flag per load: a prior request_abort() that interrupted
    # an inference must not latch and abort the next model swap.
    clear_abort()
    kwargs.setdefault("abort_callback", abort_callback)
    fa_dropped = False
    for attempt in range(_MAX_OOM_RETRIES + 1):
        try:
            # suppress_native_stderr silences llama.cpp's native stderr
            # chatter while the constructor runs.
            return suppress_native_stderr(llama_cls, **kwargs)
        except TypeError as exc:
            # Older llama-cpp-python builds reject flash_attn; drop it once
            # and retry. Any other TypeError is a genuine caller error.
            if not _drop_flash_attn_if_unsupported(exc, kwargs, fa_dropped):
                raise
            fa_dropped = True
            continue
        except ValueError as exc:
            # _raise_load_error always raises, so execution reaching the end
            # of this handler means n_ctx was halved and we should retry.
            if attempt == _MAX_OOM_RETRIES or not _is_load_oom(exc):
                _raise_load_error(model_path, kwargs, exc)
            if not _halve_ctx_for_retry(kwargs, exc):
                _raise_load_error(model_path, kwargs, exc)
    raise RuntimeError("unreachable: _construct_llama loop fell through")  # pragma: no cover

920 

921 

922def _drop_flash_attn_if_unsupported( 

923 exc: TypeError, kwargs: dict[str, Any], already_dropped: bool 

924) -> bool: 

925 """If the TypeError is about an unsupported ``flash_attn`` kwarg, drop it.""" 

926 if already_dropped or "flash_attn" not in kwargs or "flash_attn" not in str(exc): 

927 return False 

928 log.info("llama-cpp-python rejected flash_attn=True; retrying without it") 

929 kwargs.pop("flash_attn", None) 

930 return True 

931 

932 

933def _halve_ctx_for_retry(kwargs: dict[str, Any], exc: ValueError) -> bool: 

934 """Halve n_ctx (and matching batch sizes) for an OOM retry. Returns False if no progress.""" 

935 current_ctx = int(kwargs.get("n_ctx", 0) or 0) 

936 if current_ctx <= 0: 

937 return False 

938 new_ctx = max(_CTX_FLOOR, (current_ctx // 2 // _CTX_QUANTUM) * _CTX_QUANTUM) 

939 if new_ctx >= current_ctx: 

940 return False 

941 log.warning( 

942 "llama.cpp load failed at n_ctx=%d (%s); retrying at n_ctx=%d", 

943 current_ctx, 

944 str(exc).splitlines()[0], 

945 new_ctx, 

946 ) 

947 kwargs["n_ctx"] = new_ctx 

948 for key in ("n_batch", "n_ubatch"): 

949 if key in kwargs: 

950 kwargs[key] = new_ctx 

951 return True 

952 

953 

def _raise_load_error(model_path: Path, kwargs: dict[str, Any], exc: ValueError) -> None:
    """Raise the wrapped diagnostic for a llama.cpp load failure, or re-raise as-is."""
    diagnostic = _wrap_llama_load_error(model_path, kwargs, exc)
    if diagnostic is not None:
        raise diagnostic from exc
    raise exc

960 

961 

962def _is_load_oom(exc: ValueError) -> bool: 

963 """Does this ValueError look like a llama.cpp memory failure?""" 

964 err = str(exc) 

965 return "llama_context" in err or "load model from file" in err 

966 

967 

def _wrap_llama_load_error(
    model_path: Path, kwargs: dict[str, Any], exc: ValueError
) -> ValueError | None:
    """Diagnostic ValueError for opaque llama.cpp load failures, or None to pass through.

    Args:
        model_path: Model file whose load failed; name/size go into the message.
        kwargs: The llama-cpp constructor kwargs (``n_ctx`` is reported).
        exc: The original load failure.

    Returns:
        A new ValueError with actionable sizing hints, or None when *exc*
        does not look like a llama.cpp memory/load failure.
    """
    # Consistency fix: reuse the shared OOM heuristic instead of duplicating
    # its substring checks inline (they had to be kept in sync by hand).
    if not _is_load_oom(exc):
        return None
    err = str(exc)
    try:
        size_gb = model_path.stat().st_size / (1024**3) if model_path.exists() else 0.0
    except OSError:  # pragma: no cover
        # stat() can race with deletion; report 0.0 rather than fail the wrap.
        size_gb = 0.0
    n_ctx = kwargs.get("n_ctx", 0)
    # n_ctx=0 (or missing) means llama.cpp picks the model's own default.
    n_ctx_label = n_ctx or "model default"
    parts = [
        f"Failed to load {model_path.name} ({size_gb:.1f} GB) with n_ctx={n_ctx_label}.",
    ]
    try:
        import psutil

        free_gb = psutil.virtual_memory().available / (1024**3)
        parts.append(f"Host has {free_gb:.1f} GB free RAM.")
    except Exception as psu_exc:  # pragma: no cover
        # psutil is optional; the diagnostic is still useful without it.
        log.debug("psutil unavailable: %s", psu_exc)
    parts.append(
        "Try a smaller model, lower LILBEE_NUM_CTX, set LILBEE_KV_CACHE_TYPE=q8_0, "
        "or close other processes to free RAM. "
        f"(llama.cpp: {err})"
    )
    return ValueError(" ".join(parts))

997 

998 

def _is_rerank_model(model: str) -> bool:
    """Check if *model* is an exact rerank catalog entry by ref or hf_repo."""
    if model:
        return is_rerank_ref(model)
    return False