Coverage for src / lilbee / providers / llama_cpp / provider.py: 100%

468 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-28 01:01 +0000

1"""Llama.cpp provider: class, model loader, and path resolution.""" 

2 

3from __future__ import annotations 

4 

5import contextlib 

6import logging 

7import threading 

8from collections.abc import AsyncIterator, Callable 

9from dataclasses import dataclass 

10from pathlib import Path 

11from typing import Any, Literal, cast, overload 

12 

13from lilbee.app.services import get_services 

14from lilbee.catalog import is_rerank_ref 

15from lilbee.core.config import DEFAULT_NUM_CTX, cfg 

16from lilbee.core.config.enums import KV_CACHE_TYPE_BYTES, KvCacheType 

17from lilbee.providers.base import ClosableIterator, LLMProvider, ProviderError, filter_options 

18from lilbee.providers.llama_cpp.abort_signal import abort_callback, clear_abort 

19from lilbee.providers.llama_cpp.gguf_meta import ( 

20 find_mmproj_for_model, 

21 read_gguf_metadata, 

22 train_ctx_from_meta, 

23) 

24from lilbee.providers.llama_cpp.log_dispatch import ( 

25 import_llama_cpp, 

26 install_llama_log_handler, 

27 suppress_native_stderr, 

28) 

29from lilbee.providers.model_cache import ( 

30 LoaderMode, 

31 compute_dynamic_ctx, 

32 get_available_memory, 

33 kv_bytes_per_token, 

34) 

35from lilbee.providers.worker.chat_worker import chat_worker_main 

36from lilbee.providers.worker.embed_worker import embed_worker_main 

37from lilbee.providers.worker.pool import PoolRuntime, RoleAccessor 

38from lilbee.providers.worker.rerank_worker import rerank_worker_main 

39from lilbee.providers.worker.transport import ( 

40 ChatRequest, 

41 OcrBackend, 

42 PdfOcrRequest, 

43 RerankPayload, 

44 RoleConfig, 

45 VisionRequest, 

46 WorkerRole, 

47) 

48from lilbee.providers.worker.transport_pipe import WorkerCrashError, WorkerError 

49from lilbee.providers.worker.vision_worker import vision_worker_main 

50from lilbee.providers.worker.wire_kinds import WireKind 

51from lilbee.runtime.progress import EventType, ExtractEvent 

52from lilbee.vision import PageText, PdfOcrChunk, pdf_page_count 

53 

54log = logging.getLogger(__name__) 

55 

56# Vision OCR sentinel used when no per-call timeout and no ``cfg.ocr_timeout`` 

57# is set. 24h is effectively "no cap" for the round-trip wait loop. 

58_VISION_NO_CAP_TIMEOUT_S = 86_400.0 

59 

60_LLAMA_CONTEXT_PATCH_LOCK = threading.Lock() 

61"""Serialises overlapping ``_llama_n_seq_max`` callers inside one process. 

62 

63The shim mutates ``llama_cpp.internals.LlamaContext.__init__`` globally 

64while the with-block is open. Worker subprocesses each load one model 

65serially today, but the lock keeps the contract safe if a future caller 

66loads two models concurrently. 

67""" 

68 

69 

70@contextlib.contextmanager 

71def _llama_n_seq_max(n_seq_max: int) -> Any: 

72 """Set ``context_params.n_seq_max`` on the next ``LlamaContext`` constructed. 

73 

74 Workaround for llama-cpp-python upstream issue #2051 (``n_seq_max`` 

75 not exposed as a Llama kwarg). See ``docs/architecture.md`` for the 

76 full rationale and the upstream-fix removal hint. 

77 """ 

78 from llama_cpp import internals 

79 

80 with _LLAMA_CONTEXT_PATCH_LOCK: 

81 original = internals.LlamaContext.__init__ 

82 

83 def patched(self: Any, *, model: Any, params: Any, verbose: bool) -> None: 

84 params.n_seq_max = n_seq_max 

85 original(self, model=model, params=params, verbose=verbose) 

86 

87 internals.LlamaContext.__init__ = patched # type: ignore[method-assign,assignment] 

88 try: 

89 yield 

90 finally: 

91 internals.LlamaContext.__init__ = original # type: ignore[method-assign] 

92 

93 

94# Cap on tokens drained during ``_PoolChatStreamIterator.close()`` after a 

95# mid-stream cancel. A runaway model (Qwen3-0.6B stuck in a never-closing 

96# ``<think>`` loop) would otherwise block close() indefinitely. 

97_CHAT_STREAM_DRAIN_CAP = 1024 

98 

99# Chat-load OOM retry knobs. The OOM wrapper halves ``n_ctx`` (rounded down to 

100# the next ``_CTX_QUANTUM`` multiple) up to ``_MAX_OOM_RETRIES`` times before 

101# raising. ``_CTX_FLOOR`` is the smallest ``n_ctx`` we'll attempt. 

102_MAX_OOM_RETRIES = 2 

103_CTX_QUANTUM = 256 

104_CTX_FLOOR = 512 

105 

106# Sentinel passed to ``llama-cpp-python`` for "offload all layers". 

107_N_GPU_LAYERS_AUTO = -1 

108 

109# lilbee extracts a single scalar rerank score, so a multi-class reranker would 

110# be silently scored on one logit. The response can't reveal the class count, so 

111# it's checked at load. 

112_SINGLE_CLASS_RERANK = 1 

113 

114 

115def _read_context_n_seq_max(llm: Any) -> int | None: 

116 """Read ``n_seq_max`` back off a constructed Llama, or None if unreadable. 

117 

118 AttributeError means llama-cpp restructured its internals (the signal to 

119 revisit against the pin); the caller treats None as "could not verify". 

120 """ 

121 try: 

122 value = llm.context_params.n_seq_max 

123 except AttributeError: 

124 return None 

125 return int(value) if isinstance(value, int) else None 

126 

127 

128def _read_rerank_class_count(llm: Any) -> int | None: 

129 """Read a reranker's classifier-output count, or None if unreadable.""" 

130 try: 

131 model = llm._model.model 

132 except AttributeError: 

133 return None 

134 from llama_cpp import llama_cpp as _llc 

135 

136 return int(_llc.llama_model_n_cls_out(model)) 

137 

138 

def _verify_embed_n_seq_max(llm: Any, expected: int) -> None:
    """Raise if the ``_llama_n_seq_max`` shim did not reach the embedding context.

    If a future llama-cpp-python moved the ``LlamaContext`` internals, the
    shim would silently leave ``n_seq_max`` at the C default (1). Batched
    embeds of two or more texts would then fail with ``llama_decode -1`` at
    inference time. Checking right after load surfaces that failure here.

    An unreadable value (``None``) is treated as "could not verify" and
    accepted.

    Raises:
        ProviderError: the context reports an ``n_seq_max`` other than
            *expected*.
    """
    applied = _read_context_n_seq_max(llm)
    if applied is None or applied == expected:
        return
    raise ProviderError(
        f"Embedding context n_seq_max={applied}, expected {expected}; "
        "the llama-cpp-python n_seq_max workaround did not apply "
        "(internals may have changed). Batched embed would fail at decode.",
        provider="llama-cpp",
    )

154 

155 

def _verify_single_class_reranker(llm: Any) -> None:
    """Reject rerankers that emit more than one classifier output.

    lilbee takes ``embedding[0]`` as the relevance score. That is correct only
    for a single-class cross-encoder, whose sole output is the score. A
    multi-class model would be ranked on class 0 alone and produce
    plausible-looking but wrong orderings, so it is refused at load time.

    An unreadable class count (``None``) is accepted as "could not verify".

    Raises:
        ProviderError: the reported class count differs from
            ``_SINGLE_CLASS_RERANK``.
    """
    class_count = _read_rerank_class_count(llm)
    if class_count is None or class_count == _SINGLE_CLASS_RERANK:
        return
    raise ProviderError(
        f"Reranker has {class_count} classifier outputs; lilbee supports only "
        "single-class (n_cls_out=1) rerankers. A multi-class model would "
        "be scored on class 0 alone and rank incorrectly.",
        provider="llama-cpp",
    )

172 

173 

class LlamaCppProvider(LLMProvider):
    """Provider backed by llama-cpp-python for local GGUF model inference.

    Embedding, reranking, chat and vision OCR run in role workers registered
    on the Services-owned worker pool. This class registers each role on
    first use and converts worker failures into :class:`ProviderError`.
    """

    def __init__(self) -> None:
        # Serialises role registration and release against concurrent callers.
        self._pool_lock = threading.Lock()
        # Roles this provider has registered on the shared pool.
        self._registered_roles: set[WorkerRole] = set()

    @staticmethod
    def _worker_error_message(role_label: str, exc: WorkerError) -> str:
        """Render a user-facing message that names the role and points at the log.

        ``WorkerCrashError`` already embeds the log path in its message; for
        plain ``WorkerError`` (the worker reported an exception or returned
        a malformed reply) the surfaced text is the worker's exception
        repr so the user sees enough to file a bug report.
        """
        detail = str(exc)
        if detail.endswith("."):  # the wrapper supplies its own sentence-final period
            detail = detail[:-1]
        if isinstance(exc, WorkerCrashError):
            return f"{role_label} worker exited unexpectedly. {detail}. Please try again."
        return f"{role_label} worker reported an error: {detail}. Please try again."

    def _pool_runtime(self) -> PoolRuntime:
        """Return the Services-owned :class:`PoolRuntime`, starting it lazily."""
        runtime = get_services().pool_runtime
        runtime.start()
        return runtime

    def _get_pool_accessor(
        self,
        role: WorkerRole,
        worker_main: Any,
        config_factory: Callable[[], RoleConfig],
    ) -> RoleAccessor:
        """Register *role* on the Services pool the first time it is used.

        Subsequent calls return the same accessor without touching the
        pool state. Registration is gated by ``self._pool_lock`` so two
        concurrent first-callers do not race to register the role twice.
        """
        pool = get_services().worker_pool
        with self._pool_lock:
            if role not in self._registered_roles:
                pool.register(role, worker_main, config_factory)
                self._registered_roles.add(role)
        return pool.accessor(role)

    def _run_pool_call(
        self,
        runtime: PoolRuntime,
        make_call: Callable[[], Any],
        *,
        timeout: float,
        role_label: str,
        wire_name: str = "",
        result_type: type | None = None,
        expected_desc: str = "",
    ) -> Any:
        """Run one pool round trip and translate worker failures to ProviderError.

        Args:
            runtime: Pool runtime that executes the awaitable.
            make_call: Zero-arg factory for the awaitable. It is invoked inside
                the error-translating ``try`` so any ``WorkerError`` raised while
                building the call is translated too.
            timeout: Wall-clock budget in seconds passed to ``run_sync``.
            role_label: User-facing role name used in error messages.
            wire_name: Operation name used in the protocol-error message.
            result_type: If given, a result that is not an instance of this
                type is reported as a protocol error.
            expected_desc: Expected-type text for the protocol-error message.

        Returns:
            The awaited result.

        Raises:
            ProviderError: the worker crashed, reported an error, returned a
                result of the wrong type, or timed out.
        """
        try:
            result = runtime.run_sync(make_call(), timeout=timeout)
            if result_type is not None and not isinstance(result, result_type):
                raise WorkerError(
                    "ProtocolError",
                    f"Pool {wire_name} returned {type(result).__name__}, "
                    f"expected {expected_desc}.",
                    "",
                )
        except WorkerError as exc:
            raise ProviderError(
                self._worker_error_message(role_label, exc),
                provider="llama-cpp",
            ) from exc
        except TimeoutError as exc:
            raise ProviderError(
                f"{role_label} worker timed out. Please try again.",
                provider="llama-cpp",
            ) from exc
        return result

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed texts via the persistent pool worker.

        Worker crashes and timeouts surface as :class:`ProviderError`;
        the pool respawns the embed role lazily on the next call.
        """
        accessor = self._get_pool_accessor(
            WorkerRole.EMBED, embed_worker_main, _make_role_config_factory(WorkerRole.EMBED)
        )
        runtime = self._pool_runtime()
        timeout = cfg.worker_pool_call_timeout_s
        return self._run_pool_call(
            runtime,
            lambda: accessor.call(WireKind.EMBED, texts, timeout=timeout),
            timeout=timeout,
            role_label="Embedding",
            wire_name="embed",
            result_type=list,
            expected_desc="list[list[float]]",
        )

    def rerank(self, query: str, candidates: list[str]) -> list[float]:
        """Score *candidates* by relevance to *query* via the pool worker.

        An empty candidate list returns ``[]`` without touching the pool.
        """
        if not candidates:
            return []
        accessor = self._get_pool_accessor(
            WorkerRole.RERANK, rerank_worker_main, _make_role_config_factory(WorkerRole.RERANK)
        )
        runtime = self._pool_runtime()
        timeout = cfg.worker_pool_call_timeout_s
        payload = RerankPayload(query=query, candidates=candidates)
        return self._run_pool_call(
            runtime,
            lambda: accessor.call(WireKind.RERANK, payload, timeout=timeout),
            timeout=timeout,
            role_label="Rerank",
            wire_name="rerank",
            result_type=list,
            expected_desc="list[float]",
        )

    def supports_rerank(self) -> bool:
        """llama-cpp can rerank iff llama-cpp-python exposes the rank pooling type."""
        return _llama_cpp_has_rank_pooling()

    def vision_ocr(
        self, png_bytes: bytes, model: str, prompt: str = "", *, timeout: float | None = None
    ) -> str:
        """Run vision OCR on one PNG image via the persistent pool worker.

        The wall-clock budget comes from :meth:`_vision_call_budget`; an empty
        *model* is sent as ``None``.
        """
        accessor = self._get_pool_accessor(
            WorkerRole.VISION, vision_worker_main, _make_role_config_factory(WorkerRole.VISION)
        )
        runtime = self._pool_runtime()
        budget = self._vision_call_budget(timeout)
        request = VisionRequest(png_bytes=png_bytes, prompt=prompt, model=model or None)
        return self._run_pool_call(
            runtime,
            lambda: accessor.call(WireKind.VISION, request, timeout=budget),
            timeout=budget,
            role_label="Vision",
            wire_name="vision_ocr",
            result_type=str,
            expected_desc="str",
        )

    @staticmethod
    def _vision_call_budget(timeout: float | None) -> float:
        """Wall-clock budget for one vision_ocr call (per-call > cfg.ocr_timeout > no cap)."""
        effective = timeout if timeout is not None else cfg.ocr_timeout
        return float(effective) if effective and effective > 0 else _VISION_NO_CAP_TIMEOUT_S

    def pdf_ocr(
        self,
        path: Path,
        *,
        backend: OcrBackend,
        model: str = "",
        per_page_timeout_s: float | None = None,
        quiet: bool = True,
        on_progress: Callable[..., None] | None = None,
    ) -> list[PageText]:
        """Run multi-page vision PDF OCR via the persistent vision worker.

        ``per_page_timeout_s`` is *per page*. The total wall-clock cap on
        the streamed drain is ``pages * per_page + cfg.vision_load_budget_s``
        (load grace), so a 100-page scan with a 60 s per-page budget gets
        ~6000 s + load, not 60 s for the whole document.

        ``on_progress``, when given, is called with an ``EXTRACT`` event after
        each page arrives.

        Raises:
            ProviderError: the worker failed, timed out, or streamed a frame
                that is not a :class:`PdfOcrChunk`.
        """
        accessor = self._get_pool_accessor(
            WorkerRole.VISION, vision_worker_main, _make_role_config_factory(WorkerRole.VISION)
        )
        runtime = self._pool_runtime()
        budget = self._pdf_drain_budget(path, per_page_timeout_s)
        del quiet  # accepted for Protocol parity; worker has no Rich progress to suppress.
        request = PdfOcrRequest(
            path=str(path),
            backend=backend,
            model=model,
        )

        async def _drain() -> list[PageText]:
            pages: list[PageText] = []
            stream = cast(AsyncIterator[Any], accessor.stream(WireKind.PDF_OCR, request))
            async for frame in stream:
                if not isinstance(frame, PdfOcrChunk):
                    raise ProviderError(
                        f"PDF OCR worker streamed unexpected frame type {type(frame).__name__}.",
                        provider="llama-cpp",
                    )
                pages.append(PageText(frame.page, frame.text))
                if on_progress is not None:
                    on_progress(
                        EventType.EXTRACT,
                        ExtractEvent(file=path.name, page=frame.page, total_pages=frame.total),
                    )
            return pages

        return self._run_pool_call(
            runtime,
            _drain,
            timeout=budget,
            role_label="PDF OCR",
        )

    def _pdf_drain_budget(self, path: Path, per_page_timeout_s: float | None) -> float:
        """Total drain timeout = page_count * per_page + vision_load_budget_s.

        A missing or non-positive per-page budget, or a failed page-count
        probe, falls back to ``_VISION_NO_CAP_TIMEOUT_S``.
        """
        if not per_page_timeout_s or per_page_timeout_s <= 0:
            return _VISION_NO_CAP_TIMEOUT_S
        try:
            pages = pdf_page_count(path)
        except Exception:
            # If we can't probe pages upfront the worker will still try, but
            # we lose the precise budget; fall back to no-cap so the parent
            # doesn't kill a valid run on a probe failure. Log so the lost
            # cap is diagnosable.
            log.debug(
                "pdf_page_count failed for %s; using uncapped drain budget", path, exc_info=True
            )
            return _VISION_NO_CAP_TIMEOUT_S
        return float(pages) * per_page_timeout_s + cfg.vision_load_budget_s

    @overload
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: Literal[False] = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str: ...

    @overload
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: Literal[True],
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> ClosableIterator[str]: ...

    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: bool = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str | ClosableIterator[str]:
        """Chat completion via the persistent pool worker.

        Streaming returns a :class:`ClosableIterator[str]` whose
        ``close()`` flips the worker's abort flag so in-flight generation
        drains cleanly. Non-streaming returns the assembled assistant text.
        """
        accessor = self._get_pool_accessor(
            WorkerRole.CHAT, chat_worker_main, _make_role_config_factory(WorkerRole.CHAT)
        )
        runtime = self._pool_runtime()
        accessor.clear_abort()  # honor mid-stream cancels from the previous turn
        request = ChatRequest(
            messages=messages,
            stream=stream,
            options=self._chat_kwargs_from_options(options) or None,
            model=model,
        )
        if stream:
            return _PoolChatStreamIterator(
                runtime=runtime,
                accessor=accessor,
                async_iter=accessor.stream(WireKind.CHAT, request),
            )
        timeout = cfg.worker_pool_call_timeout_s
        return self._run_pool_call(
            runtime,
            lambda: accessor.call(WireKind.CHAT, request, timeout=timeout),
            timeout=timeout,
            role_label="Chat",
            wire_name="chat",
            result_type=str,
            expected_desc="str",
        )

    @staticmethod
    def _chat_kwargs_from_options(options: dict[str, Any] | None) -> dict[str, Any]:
        """Translate user-facing options into llama-cpp create_chat_completion kwargs."""
        if not options:
            return {}
        filtered = filter_options(options)
        if "num_predict" in filtered:
            filtered["max_tokens"] = filtered.pop("num_predict")
        filtered.pop("num_ctx", None)  # model-load param, not per-call
        # top_k kept here (local llama.cpp honors it); the API path in
        # translate_options strips it since hosted providers ignore it.
        return filtered

    def list_models(self) -> list[str]:
        """List installed models from registry, sorted by ref."""
        registry = get_services().registry
        return sorted(m.ref for m in registry.list_installed())

    def list_chat_models(self, provider: str) -> list[str]:
        """llama-cpp has no frontier-provider catalog; always ``[]``."""
        return []

    def pull_model(self, model: str, *, on_progress: Callable[..., Any] | None = None) -> None:
        """Not supported directly: ``lilbee.catalog`` handles downloads."""
        raise NotImplementedError(
            f"llama-cpp provider cannot pull model {model!r}. "
            "Download GGUF files manually or use the catalog."
        )

    def show_model(self, model: str) -> dict[str, Any] | None:
        """Return model metadata from GGUF headers, or ``None`` if unresolvable."""
        try:
            path = resolve_model_path(model)
        except ProviderError:
            return None
        return read_gguf_metadata(path)

    def get_capabilities(self, model: str) -> list[str]:
        """Detect capabilities from local GGUF files.

        Rerank models return ``["rerank"]``; cross-encoder GGUFs cannot
        generate text. Other models report ``"completion"``, plus
        ``"vision"`` when an mmproj sidecar is present.
        """
        if _is_rerank_model(model):
            return ["rerank"]
        caps: list[str] = ["completion"]
        try:
            path = resolve_model_path(model)
        except ProviderError:
            log.debug("resolve_model_path failed for %s", model, exc_info=True)
            return caps
        try:
            find_mmproj_for_model(path)
            caps.append("vision")
        except ProviderError:
            log.debug("no mmproj for %s", model, exc_info=True)
        return caps

    def warm_up_pool(self) -> None:
        """Register roles for every configured model. Idempotent.

        Called by ``Services`` when ``cfg.worker_pool_eager_start`` is on so
        ``WorkerPool.start_eager()`` has roles to spawn. Roles whose model is
        unset are skipped; this lets a setup with only ``chat_model`` +
        ``embedding_model`` configured eager-start exactly those two and not
        pay rerank or vision spawn cost.
        """
        for role in _ROLE_SPECS:
            if not _is_role_configured(role):
                continue
            self._get_pool_accessor(role, _ROLE_ENTRYPOINTS[role], _make_role_config_factory(role))

    def shutdown(self) -> None:
        """Drop pool registrations so a follow-up provider can re-register cleanly."""
        self._release_pool_roles()

    def _release_pool_roles(self) -> None:
        """Drop our registrations on the Services pool so the next call respawns.

        Safe even when Services has not yet been built (early shutdown
        on import-time failure). Holds ``self._pool_lock`` so a concurrent
        ``_get_pool_accessor`` does not race the role removal.
        """
        with self._pool_lock:
            roles = tuple(self._registered_roles)
            self._registered_roles.clear()
        if not roles:
            return
        from lilbee.providers.worker.pool import PoolShutdownError

        services = get_services()
        runtime = services.pool_runtime
        for role in roles:
            try:
                runtime.run_sync(services.worker_pool.release(role), timeout=10.0)
            except PoolShutdownError:
                # Pool already shut down (atexit ordering during a CLI exit
                # tears down the pool runtime before this provider). Nothing
                # to release; silent no-op.
                pass
            except (TimeoutError, RuntimeError, OSError) as exc:
                log.warning("Pool release of role=%s raised %s", role, exc)

    def invalidate_load_cache(self, model_path: Path | None = None) -> None:
        """Drop the pool's per-role workers so the next call respawns with current settings.

        The ``model_path`` argument is accepted for protocol parity with
        other providers but does not narrow the scope: workers reload all
        their roles on respawn anyway.
        """
        del model_path
        self._release_pool_roles()

594 

595 

596class _PoolChatStreamIterator: 

597 """Sync facade over an async chat-stream iterator from the worker pool. 

598 

599 Each ``__next__`` submits one ``__anext__`` to the pool's runtime 

600 loop and blocks for the result. ``close()`` flips the worker's abort 

601 flag so any in-flight generation stops at the next token-tick; 

602 in-flight chunks already in the pipe still drain. 

603 """ 

604 

605 def __init__( 

606 self, 

607 *, 

608 runtime: PoolRuntime, 

609 accessor: RoleAccessor, 

610 async_iter: Any, 

611 ) -> None: 

612 self._runtime = runtime 

613 self._accessor = accessor 

614 self._async_iter = async_iter 

615 self._exhausted = False 

616 

617 def __iter__(self) -> _PoolChatStreamIterator: 

618 return self 

619 

620 def __next__(self) -> str: 

621 if self._exhausted: 

622 raise StopIteration 

623 try: 

624 chunk: str = self._runtime.run_sync( 

625 self._async_iter.__anext__(), 

626 timeout=cfg.worker_pool_call_timeout_s, 

627 ) 

628 return chunk 

629 except StopAsyncIteration: 

630 self._exhausted = True 

631 raise StopIteration from None 

632 except WorkerError as exc: 

633 # Mid-stream worker crashes propagate as ProviderError so the 

634 # streaming path matches the non-streaming contract. 

635 self._exhausted = True 

636 raise ProviderError( 

637 LlamaCppProvider._worker_error_message("Chat", exc), 

638 provider="llama-cpp", 

639 ) from exc 

640 except TimeoutError as exc: 

641 self._exhausted = True 

642 raise ProviderError( 

643 "Chat worker timed out mid-stream. Please try again.", 

644 provider="llama-cpp", 

645 ) from exc 

646 

647 def close(self) -> None: 

648 """Cancel mid-stream and drain remaining tokens from the pipe. 

649 

650 Drain is bounded by ``_CHAT_STREAM_DRAIN_CAP`` so a stuck 

651 worker cannot block close() indefinitely; once the cap fires we 

652 accept the partial-state for not hanging the UI. 

653 """ 

654 if self._exhausted: 

655 return 

656 self._accessor.cancel() 

657 drained = 0 

658 while drained < _CHAT_STREAM_DRAIN_CAP: 

659 try: 

660 next(self) 

661 except StopIteration: 

662 break 

663 except Exception: 

664 break 

665 drained += 1 

666 self._accessor.clear_abort() 

667 self._exhausted = True 

668 

669 def __del__(self) -> None: # pragma: no cover 

670 with contextlib.suppress(Exception): 

671 self.close() 

672 

673 

674@dataclass(frozen=True) 

675class _RoleSpec: 

676 """Per-role recipe for building a :class:`RoleConfig` from cfg.""" 

677 

678 cfg_attr: str 

679 mode: str 

680 

681 

# Per-role recipe: which cfg attribute names the model, and the loader mode.
_ROLE_SPECS: dict[WorkerRole, _RoleSpec] = {
    WorkerRole.EMBED: _RoleSpec("embedding_model", LoaderMode.EMBED),
    WorkerRole.RERANK: _RoleSpec("reranker_model", LoaderMode.RERANK),
    WorkerRole.CHAT: _RoleSpec("chat_model", LoaderMode.CHAT),
    # Vision does not go through load_llama: the vision worker calls its own
    # mtmd loader (load_vision_llama), so this mode string is informational.
    WorkerRole.VISION: _RoleSpec("vision_model", "vision"),
}


# Worker-process entrypoint per role; warm_up_pool registers roles with these.
_ROLE_ENTRYPOINTS: dict[WorkerRole, Callable[..., None]] = {
    WorkerRole.EMBED: embed_worker_main,
    WorkerRole.RERANK: rerank_worker_main,
    WorkerRole.CHAT: chat_worker_main,
    WorkerRole.VISION: vision_worker_main,
}

698 

699 

def _is_role_configured(role: WorkerRole) -> bool:
    """Return True when ``cfg`` holds a truthy (non-empty) model name for *role*."""
    configured_model = getattr(cfg, _ROLE_SPECS[role].cfg_attr)
    return bool(configured_model)

703 

704 

def _make_role_config_factory(role: WorkerRole) -> Callable[[], RoleConfig]:
    """Build a zero-arg factory that resolves *role*'s model when invoked.

    The pool calls the factory on every spawn (lazy or restart) so model
    swaps in cfg propagate without an explicit invalidation call.

    The returned factory raises ``ProviderError`` if the role's cfg attribute
    is empty, and propagates any error from :func:`resolve_model_path`.
    """
    spec = _ROLE_SPECS[role]
    cfg_attr, loader_mode = spec.cfg_attr, spec.mode

    def _build() -> RoleConfig:
        model_name = getattr(cfg, cfg_attr)
        if model_name:
            return RoleConfig(
                role=role,
                model_path=resolve_model_path(model_name),
                mode=loader_mode,
            )
        raise ProviderError(
            f"No {role} model configured. Set cfg.{cfg_attr} first.",
            provider="llama-cpp",
        )

    return _build

727 

728 

def resolve_model_path(model: str) -> Path:
    """Resolve a model name to a ``.gguf`` file path.

    Resolution order:

    1. The model registry (canonical source for installed models).
    2. An absolute filesystem path, accepted only if it names an existing
       regular file.

    Raises:
        ProviderError: *model* is not in the registry and is either not an
            absolute path or does not point to an existing file.
    """
    registry = get_services().registry
    try:
        return registry.resolve(model)
    except (KeyError, ValueError):
        pass  # not registered; fall through to path resolution

    candidate = Path(model)
    if candidate.is_absolute():
        # is_file() rather than exists(): an absolute *directory* would
        # otherwise be returned here and handed to the model loader.
        if candidate.is_file():
            return candidate
        raise ProviderError(f"Model file not found: {model}", provider="llama-cpp")

    raise ProviderError(
        f"Model {model!r} not found in registry. "
        "Install it via the catalog or 'lilbee model pull'.",
        provider="llama-cpp",
    )

753 

754 

def _llama_cpp_has_rank_pooling() -> bool:
    """Report whether llama-cpp-python exposes ``LLAMA_POOLING_TYPE_RANK``.

    This may run before any model is loaded, for example during feature
    detection in status or catalog UIs. The import therefore goes through
    ``import_llama_cpp()`` first, which turns a bare-Linux missing-libvulkan
    OSError into a ProviderError with a helpful hint.

    Only ``ImportError`` (package absent, or symbol missing) is mapped to
    False. A libvulkan-style install error is a real problem and propagates
    to the caller as ProviderError.
    """
    try:
        import_llama_cpp()
        from llama_cpp import LLAMA_POOLING_TYPE_RANK  # noqa: F401
    except ImportError:
        return False
    else:
        return True

769 

770 

def load_llama(
    model_path: Path,
    *,
    mode: LoaderMode,
    abort_callback_override: Any = None,
) -> Any:
    """Construct a ``llama_cpp.Llama`` for chat, embedding, or rerank use.

    Args:
        model_path: GGUF file to load.
        mode: Loader mode. EMBED and RERANK load with ``embedding=True``;
            RERANK additionally uses rank pooling.
        abort_callback_override: Optional abort callback passed to ``Llama``;
            pool workers bind one that reads the worker's shared ``mp.Value``
            abort flag.

    Returns:
        The constructed Llama instance.

    Raises:
        ProviderError: for embedding/rerank loads, when the ``n_seq_max``
            shim did not apply or the reranker has more than one output class.
            Errors from ``_construct_llama`` propagate unchanged.
    """
    llama_cls = import_llama_cpp().Llama
    install_llama_log_handler()

    is_embedding = mode in (LoaderMode.EMBED, LoaderMode.RERANK)
    kwargs: dict[str, Any] = {
        "model_path": str(model_path),
        "embedding": is_embedding,
        "verbose": False,
        "n_gpu_layers": _resolve_n_gpu_layers(embedding=is_embedding),
    }
    if cfg.main_gpu is not None:
        kwargs["main_gpu"] = cfg.main_gpu

    if is_embedding:
        # Embedding/rerank always uses the model's training context.
        # cfg.num_ctx is a chat setting; applying it here once clamped the
        # rerank model below what a query+candidate pair needs and produced
        # "llama_decode returned 1" on every other query whenever the user
        # chose a small chat ctx for a low-RAM box. An explicit value (rather
        # than 0 for "model default") keeps the OOM-retry path working, since
        # ``_halve_ctx_for_retry`` cannot bisect from 0.
        embed_meta = _safe_read_gguf_metadata(model_path)
        context_len = train_ctx_from_meta(
            embed_meta, fallback=_EMBED_FALLBACK_CTX, model_path=model_path
        )
        kwargs["n_ctx"] = context_len
        # llama-cpp-python defaults n_batch to min(n_ctx, 512), which silently
        # truncates embeddings to 512 tokens; match batch sizes to n_ctx so
        # each text can use the model's full context window.
        kwargs["n_batch"] = context_len
        kwargs["n_ubatch"] = context_len
        if mode == LoaderMode.RERANK:
            from llama_cpp import LLAMA_POOLING_TYPE_RANK

            kwargs["pooling_type"] = LLAMA_POOLING_TYPE_RANK
    else:
        if cfg.num_ctx is not None:
            kwargs["n_ctx"] = cfg.num_ctx
        else:
            chat_meta = _safe_read_gguf_metadata(model_path)
            kwargs["n_ctx"] = _resolve_chat_ctx(model_path, chat_meta)
            log.info(
                "Chat n_ctx=%d for %s (dynamic, training_ctx=%s)",
                kwargs["n_ctx"],
                model_path.name,
                (chat_meta or {}).get("context_length", "unknown"),
            )
        _apply_flash_attention(kwargs)
        _apply_kv_cache_type(kwargs)

    if abort_callback_override is not None:
        kwargs["abort_callback"] = abort_callback_override

    if not is_embedding:
        return _construct_llama(llama_cls, model_path, kwargs)

    from lilbee.providers.llama_cpp.batching import EMBED_N_SEQ_MAX

    with _llama_n_seq_max(EMBED_N_SEQ_MAX):
        llm = _construct_llama(llama_cls, model_path, kwargs)
    _verify_embed_n_seq_max(llm, EMBED_N_SEQ_MAX)
    if mode == LoaderMode.RERANK:
        _verify_single_class_reranker(llm)
    return llm

850 

851 

def _safe_read_gguf_metadata(model_path: Path) -> dict[str, str] | None:
    """Read GGUF metadata, returning ``None`` instead of raising on any failure.

    Failures are logged at debug level with the traceback attached.
    """
    try:
        metadata = read_gguf_metadata(model_path)
    except Exception:
        log.debug("read_gguf_metadata failed for %s", model_path, exc_info=True)
        return None
    return metadata

859 

860 

861# Fallback used when an embedding GGUF reports zero, negative, or 

862# unparseable ``context_length`` in its metadata header. Some published 

863# nomic-embed and Qwen3 GGUFs in the wild report ``0`` (the b473 QA dump 

864# logged ``n_ctx_seq (512) > n_ctx_train (0)``). 2048 is the documented 

865# training-context for the smallest featured embedder that uses it 

866# (Google's EmbeddingGemma-300m, see 

867# https://huggingface.co/google/embeddinggemma-300m), and llama.cpp 

868# tolerates n_ctx > n_ctx_train with a warning, so the larger nomic 

869# embedder still loads cleanly under the same fallback. 

870_EMBED_FALLBACK_CTX = 2048 

871 

872 

873def _resolve_chat_ctx(model_path: Path, meta: dict[str, str] | None) -> int: 

874 """Pick n_ctx aiming for ``cfg.chat_n_ctx_target``, clamped to model + host. 

875 

876 When ``cfg.num_ctx_max`` is ``None`` the model's training_ctx is the only 

877 ceiling, so a long-context model can grow past the target if the host 

878 has the RAM to back it. Setting ``num_ctx_max`` explicitly caps below 

879 training_ctx for per-host policy reasons. 

880 """ 

881 training_ctx = train_ctx_from_meta(meta, fallback=DEFAULT_NUM_CTX, model_path=model_path) 

882 ceiling = cfg.num_ctx_max if cfg.num_ctx_max is not None else training_ctx 

883 

884 try: 

885 model_bytes = model_path.stat().st_size 

886 available = get_available_memory(cfg.gpu_memory_fraction) 

887 kv_per_tok = kv_bytes_per_token(meta, _kv_elem_bytes_for_cfg()) 

888 return compute_dynamic_ctx( 

889 model_bytes=model_bytes, 

890 available_bytes=available, 

891 training_ctx=training_ctx, 

892 kv_bytes_per_tok=kv_per_tok, 

893 ceiling=ceiling, 

894 target=cfg.chat_n_ctx_target, 

895 ) 

896 except (OSError, ValueError): 

897 log.debug("dynamic ctx sizing failed for %s, using static cap", model_path, exc_info=True) 

898 return min(training_ctx, cfg.chat_n_ctx_target) 

899 

900 

901def _kv_elem_bytes_for_cfg() -> int: 

902 """Bytes per KV element implied by the configured cache type.""" 

903 return KV_CACHE_TYPE_BYTES[cfg.kv_cache_type] 

904 

905 

906def _resolve_n_gpu_layers(*, embedding: bool) -> int: 

907 """Resolve ``cfg.n_gpu_layers`` (None=all) to llama-cpp's offload integer.""" 

908 if embedding or cfg.n_gpu_layers is None: 

909 return _N_GPU_LAYERS_AUTO 

910 return cfg.n_gpu_layers 

911 

912 

913def _apply_flash_attention(kwargs: dict[str, Any]) -> None: 

914 """Set ``flash_attn`` per ``cfg.flash_attention`` (None=auto, True/False=force).""" 

915 if cfg.flash_attention is False: 

916 return 

917 # None (auto) and True both pass flash_attn=True; the construct loop 

918 # drops it on TypeError if llama-cpp-python doesn't support it. 

919 kwargs["flash_attn"] = True 

920 

921 

922def _apply_kv_cache_type(kwargs: dict[str, Any]) -> None: 

923 """Map ``cfg.kv_cache_type`` to llama-cpp-python ``type_k`` / ``type_v``.""" 

924 if cfg.kv_cache_type is KvCacheType.F16: 

925 return 

926 type_map = _ggml_type_map() 

927 if type_map is None: 

928 log.debug("llama_cpp internal types unavailable; skipping KV quant") 

929 return 

930 ggml_type = type_map.get(cfg.kv_cache_type) 

931 if ggml_type is None: # pragma: no cover -- defensive against new enum values 

932 return 

933 kwargs["type_k"] = ggml_type 

934 kwargs["type_v"] = ggml_type 

935 

936 

937def _ggml_type_map() -> dict[KvCacheType, Any] | None: 

938 """Resolve llama-cpp-python's GGML_TYPE_* constants, or None on older builds.""" 

939 try: 

940 from llama_cpp import llama_cpp as _llc 

941 except Exception: # pragma: no cover -- only fires on llama-cpp-python without _llc 

942 return None 

943 return { 

944 KvCacheType.F32: getattr(_llc, "GGML_TYPE_F32", None), 

945 KvCacheType.F16: getattr(_llc, "GGML_TYPE_F16", None), 

946 KvCacheType.Q8_0: getattr(_llc, "GGML_TYPE_Q8_0", None), 

947 KvCacheType.Q4_0: getattr(_llc, "GGML_TYPE_Q4_0", None), 

948 } 

949 

950 

951def _construct_llama(llama_cls: Any, model_path: Path, kwargs: dict[str, Any]) -> Any: 

952 """Call ``llama_cls(**kwargs)`` with FA fallback and OOM-retry-with-halved-ctx. 

953 

954 Each loop iteration either returns the loaded model, raises (failure 

955 or unrelated TypeError), or continues with halved n_ctx; the loop is 

956 therefore structurally exhaustive and never falls through. 

957 """ 

958 # Fresh abort flag per load: a prior request_abort() that interrupted 

959 # an inference must not latch and abort the next model swap. 

960 clear_abort() 

961 kwargs.setdefault("abort_callback", abort_callback) 

962 fa_dropped = False 

963 for attempt in range(_MAX_OOM_RETRIES + 1): 

964 try: 

965 return suppress_native_stderr(llama_cls, **kwargs) 

966 except TypeError as exc: 

967 if not _drop_flash_attn_if_unsupported(exc, kwargs, fa_dropped): 

968 raise 

969 fa_dropped = True 

970 continue 

971 except ValueError as exc: 

972 if attempt == _MAX_OOM_RETRIES or not _is_load_oom(exc): 

973 _raise_load_error(model_path, kwargs, exc) 

974 if not _halve_ctx_for_retry(kwargs, exc): 

975 _raise_load_error(model_path, kwargs, exc) 

976 raise RuntimeError("unreachable: _construct_llama loop fell through") # pragma: no cover 

977 

978 

979def _drop_flash_attn_if_unsupported( 

980 exc: TypeError, kwargs: dict[str, Any], already_dropped: bool 

981) -> bool: 

982 """If the TypeError is about an unsupported ``flash_attn`` kwarg, drop it.""" 

983 if already_dropped or "flash_attn" not in kwargs or "flash_attn" not in str(exc): 

984 return False 

985 log.info("llama-cpp-python rejected flash_attn=True; retrying without it") 

986 kwargs.pop("flash_attn", None) 

987 return True 

988 

989 

990def _halve_ctx_for_retry(kwargs: dict[str, Any], exc: ValueError) -> bool: 

991 """Halve n_ctx (and matching batch sizes) for an OOM retry. Returns False if no progress.""" 

992 current_ctx = int(kwargs.get("n_ctx", 0) or 0) 

993 if current_ctx <= 0: 

994 return False 

995 new_ctx = max(_CTX_FLOOR, (current_ctx // 2 // _CTX_QUANTUM) * _CTX_QUANTUM) 

996 if new_ctx >= current_ctx: 

997 return False 

998 log.warning( 

999 "llama.cpp load failed at n_ctx=%d (%s); retrying at n_ctx=%d", 

1000 current_ctx, 

1001 str(exc).splitlines()[0], 

1002 new_ctx, 

1003 ) 

1004 kwargs["n_ctx"] = new_ctx 

1005 for key in ("n_batch", "n_ubatch"): 

1006 if key in kwargs: 

1007 kwargs[key] = new_ctx 

1008 return True 

1009 

1010 

1011def _raise_load_error(model_path: Path, kwargs: dict[str, Any], exc: ValueError) -> None: 

1012 """Raise the wrapped diagnostic for a llama.cpp load failure, or re-raise as-is.""" 

1013 wrapped = _wrap_llama_load_error(model_path, kwargs, exc) 

1014 if wrapped is None: 

1015 raise exc 

1016 raise wrapped from exc 

1017 

1018 

1019def _is_load_oom(exc: ValueError) -> bool: 

1020 """Does this ValueError look like a llama.cpp memory failure?""" 

1021 err = str(exc) 

1022 return "llama_context" in err or "load model from file" in err 

1023 

1024 

1025def _wrap_llama_load_error( 

1026 model_path: Path, kwargs: dict[str, Any], exc: ValueError 

1027) -> ValueError | None: 

1028 """Diagnostic ValueError for opaque llama.cpp load failures, or None to pass through.""" 

1029 err = str(exc) 

1030 if "llama_context" not in err and "load model from file" not in err: 

1031 return None 

1032 try: 

1033 size_gb = model_path.stat().st_size / (1024**3) if model_path.exists() else 0.0 

1034 except OSError: # pragma: no cover 

1035 size_gb = 0.0 

1036 n_ctx = kwargs.get("n_ctx", 0) 

1037 n_ctx_label = n_ctx or "model default" 

1038 parts = [ 

1039 f"Failed to load {model_path.name} ({size_gb:.1f} GB) with n_ctx={n_ctx_label}.", 

1040 ] 

1041 try: 

1042 import psutil 

1043 

1044 free_gb = psutil.virtual_memory().available / (1024**3) 

1045 parts.append(f"Host has {free_gb:.1f} GB free RAM.") 

1046 except Exception as psu_exc: # pragma: no cover 

1047 log.debug("psutil unavailable: %s", psu_exc) 

1048 parts.append( 

1049 "Try a smaller model, lower LILBEE_NUM_CTX, set LILBEE_KV_CACHE_TYPE=q8_0, " 

1050 "or close other processes to free RAM. " 

1051 f"(llama.cpp: {err})" 

1052 ) 

1053 return ValueError(" ".join(parts)) 

1054 

1055 

1056def _is_rerank_model(model: str) -> bool: 

1057 """Check if *model* is an exact rerank catalog entry by ref or hf_repo.""" 

1058 if not model: 

1059 return False 

1060 return is_rerank_ref(model)