Coverage for src / lilbee / providers / llama_cpp / provider.py: 100%
468 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
1"""Llama.cpp provider: class, model loader, and path resolution."""
3from __future__ import annotations
5import contextlib
6import logging
7import threading
8from collections.abc import AsyncIterator, Callable
9from dataclasses import dataclass
10from pathlib import Path
11from typing import Any, Literal, cast, overload
13from lilbee.app.services import get_services
14from lilbee.catalog import is_rerank_ref
15from lilbee.core.config import DEFAULT_NUM_CTX, cfg
16from lilbee.core.config.enums import KV_CACHE_TYPE_BYTES, KvCacheType
17from lilbee.providers.base import ClosableIterator, LLMProvider, ProviderError, filter_options
18from lilbee.providers.llama_cpp.abort_signal import abort_callback, clear_abort
19from lilbee.providers.llama_cpp.gguf_meta import (
20 find_mmproj_for_model,
21 read_gguf_metadata,
22 train_ctx_from_meta,
23)
24from lilbee.providers.llama_cpp.log_dispatch import (
25 import_llama_cpp,
26 install_llama_log_handler,
27 suppress_native_stderr,
28)
29from lilbee.providers.model_cache import (
30 LoaderMode,
31 compute_dynamic_ctx,
32 get_available_memory,
33 kv_bytes_per_token,
34)
35from lilbee.providers.worker.chat_worker import chat_worker_main
36from lilbee.providers.worker.embed_worker import embed_worker_main
37from lilbee.providers.worker.pool import PoolRuntime, RoleAccessor
38from lilbee.providers.worker.rerank_worker import rerank_worker_main
39from lilbee.providers.worker.transport import (
40 ChatRequest,
41 OcrBackend,
42 PdfOcrRequest,
43 RerankPayload,
44 RoleConfig,
45 VisionRequest,
46 WorkerRole,
47)
48from lilbee.providers.worker.transport_pipe import WorkerCrashError, WorkerError
49from lilbee.providers.worker.vision_worker import vision_worker_main
50from lilbee.providers.worker.wire_kinds import WireKind
51from lilbee.runtime.progress import EventType, ExtractEvent
52from lilbee.vision import PageText, PdfOcrChunk, pdf_page_count
log = logging.getLogger(__name__)

# Fallback wall-clock budget for vision OCR, used when the caller passes no
# per-call timeout and ``cfg.ocr_timeout`` is unset. Twenty-four hours means
# the round-trip wait loop is effectively uncapped.
_VISION_NO_CAP_TIMEOUT_S = 86_400.0

# Serialises overlapping ``_llama_n_seq_max`` callers within one process.
# While its with-block is open, that shim rebinds
# ``llama_cpp.internals.LlamaContext.__init__`` globally. Worker subprocesses
# currently load one model at a time, but holding this lock keeps the shim
# safe if a future caller loads two models concurrently.
_LLAMA_CONTEXT_PATCH_LOCK = threading.Lock()
@contextlib.contextmanager
def _llama_n_seq_max(n_seq_max: int) -> Any:
    """Force ``context_params.n_seq_max`` on every ``LlamaContext`` built inside the block.

    Workaround for llama-cpp-python upstream issue #2051 (``n_seq_max`` is not
    exposed as a ``Llama`` kwarg). See ``docs/architecture.md`` for the full
    rationale and the hint for removing this once upstream ships a fix.

    While the with-block is open, ``llama_cpp.internals.LlamaContext.__init__``
    is replaced process-wide by a wrapper. The wrapper writes *n_seq_max* onto
    the ``params`` object and then delegates to the original constructor. The
    original is restored on exit, even if the body raises.
    ``_LLAMA_CONTEXT_PATCH_LOCK`` is held for the whole block.
    """
    from llama_cpp import internals

    with _LLAMA_CONTEXT_PATCH_LOCK:
        original = internals.LlamaContext.__init__

        def patched(self: Any, *, model: Any, params: Any, verbose: bool = True) -> None:
            # ``verbose`` defaults to True to mirror upstream's signature, so a
            # construction that omits it is not turned into a TypeError by the
            # patch.
            params.n_seq_max = n_seq_max
            original(self, model=model, params=params, verbose=verbose)

        internals.LlamaContext.__init__ = patched  # type: ignore[method-assign,assignment]
        try:
            yield
        finally:
            internals.LlamaContext.__init__ = original  # type: ignore[method-assign]
# Upper bound on tokens that ``_PoolChatStreamIterator.close()`` drains after
# a mid-stream cancel. Without it, a runaway model (e.g. Qwen3-0.6B looping
# inside a ``<think>`` block that never closes) could block close() forever.
_CHAT_STREAM_DRAIN_CAP = 1024

# Chat-load OOM retry tuning. On OOM the wrapper halves ``n_ctx``, rounded
# down to a multiple of ``_CTX_QUANTUM``, at most ``_MAX_OOM_RETRIES`` times
# before giving up. ``_CTX_FLOOR`` is the smallest ``n_ctx`` that is tried.
_MAX_OOM_RETRIES = 2
_CTX_QUANTUM = 256
_CTX_FLOOR = 512

# Value passed to llama-cpp-python's ``n_gpu_layers`` meaning "offload every layer".
_N_GPU_LAYERS_AUTO = -1

# lilbee pulls a single scalar score out of a rerank response, so a
# multi-class reranker would silently be scored on one logit. The response
# does not reveal how many classes the model has, so the count is checked
# when the model is loaded.
_SINGLE_CLASS_RERANK = 1
115def _read_context_n_seq_max(llm: Any) -> int | None:
116 """Read ``n_seq_max`` back off a constructed Llama, or None if unreadable.
118 AttributeError means llama-cpp restructured its internals (the signal to
119 revisit against the pin); the caller treats None as "could not verify".
120 """
121 try:
122 value = llm.context_params.n_seq_max
123 except AttributeError:
124 return None
125 return int(value) if isinstance(value, int) else None
128def _read_rerank_class_count(llm: Any) -> int | None:
129 """Read a reranker's classifier-output count, or None if unreadable."""
130 try:
131 model = llm._model.model
132 except AttributeError:
133 return None
134 from llama_cpp import llama_cpp as _llc
136 return int(_llc.llama_model_n_cls_out(model))
def _verify_embed_n_seq_max(llm: Any, expected: int) -> None:
    """Raise ProviderError if the ``_llama_n_seq_max`` shim did not take effect.

    If a future llama-cpp-python moves the ``LlamaContext`` internals, the
    shim stops applying and ``n_seq_max`` stays at the C default (1). Batched
    (2+) embedding would then fail later with ``llama_decode -1`` at
    inference time. This check surfaces that at load instead. An unreadable
    value (None) is accepted as "could not verify".
    """
    applied = _read_context_n_seq_max(llm)
    if applied is None or applied == expected:
        return
    raise ProviderError(
        f"Embedding context n_seq_max={applied}, expected {expected}; "
        "the llama-cpp-python n_seq_max workaround did not apply "
        "(internals may have changed). Batched embed would fail at decode.",
        provider="llama-cpp",
    )
def _verify_single_class_reranker(llm: Any) -> None:
    """Raise ProviderError when the reranker reports more than one classifier output.

    lilbee takes ``embedding[0]`` as the relevance score. For a single-class
    cross-encoder that is the model's only output. A multi-class reranker
    would be scored on class 0 alone and produce plausible-but-wrong
    rankings, so it is rejected here. An unreadable count (None) is accepted.
    """
    class_count = _read_rerank_class_count(llm)
    if class_count is None or class_count == _SINGLE_CLASS_RERANK:
        return
    raise ProviderError(
        f"Reranker has {class_count} classifier outputs; lilbee supports only "
        "single-class (n_cls_out=1) rerankers. A multi-class model would "
        "be scored on class 0 alone and rank incorrectly.",
        provider="llama-cpp",
    )
class LlamaCppProvider(LLMProvider):
    """Provider backed by llama-cpp-python for local GGUF model inference.

    Embed, rerank, vision OCR, PDF OCR and chat all run in per-role workers
    on the Services-owned worker pool. This object only records which roles
    it has registered on that pool, so it can release them again later.
    """

    def __init__(self) -> None:
        # Guards ``_registered_roles``: concurrent first-callers register a
        # role only once, and release does not race registration.
        self._pool_lock = threading.Lock()
        self._registered_roles: set[WorkerRole] = set()

    @staticmethod
    def _worker_error_message(role_label: str, exc: WorkerError) -> str:
        """Render a user-facing message that names the role and points at the log.

        ``WorkerCrashError`` already embeds the log path in its message. For a
        plain ``WorkerError`` (the worker reported an exception or returned a
        malformed reply), the surfaced text is the worker's exception repr, so
        the user sees enough to file a bug report.
        """
        detail = str(exc)
        if detail.endswith("."):  # the wrapper supplies its own sentence-final period
            detail = detail[:-1]
        if isinstance(exc, WorkerCrashError):
            return f"{role_label} worker exited unexpectedly. {detail}. Please try again."
        return f"{role_label} worker reported an error: {detail}. Please try again."

    def _pool_runtime(self) -> PoolRuntime:
        """Return the Services-owned :class:`PoolRuntime`, starting it lazily."""
        runtime = get_services().pool_runtime
        runtime.start()
        return runtime

    def _get_pool_accessor(
        self,
        role: WorkerRole,
        worker_main: Any,
        config_factory: Callable[[], RoleConfig],
    ) -> RoleAccessor:
        """Register *role* on the Services pool the first time it is used.

        Later calls return an accessor without registering again.
        Registration is gated by ``self._pool_lock``, so two concurrent
        first-callers do not both register the role.
        """
        pool = get_services().worker_pool
        with self._pool_lock:
            if role not in self._registered_roles:
                pool.register(role, worker_main, config_factory)
                self._registered_roles.add(role)
        return pool.accessor(role)

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed texts via the persistent pool worker.

        An empty *texts* returns ``[]`` without touching the pool, the same
        way :meth:`rerank` handles empty candidates. Worker crashes and
        timeouts surface as :class:`ProviderError`; the pool respawns the
        embed role lazily on the next call.
        """
        if not texts:
            return []
        accessor = self._get_pool_accessor(
            WorkerRole.EMBED, embed_worker_main, _make_role_config_factory(WorkerRole.EMBED)
        )
        runtime = self._pool_runtime()
        try:
            result = runtime.run_sync(
                accessor.call(WireKind.EMBED, texts, timeout=cfg.worker_pool_call_timeout_s),
                timeout=cfg.worker_pool_call_timeout_s,
            )
            if not isinstance(result, list):
                raise WorkerError(
                    "ProtocolError",
                    f"Pool embed returned {type(result).__name__}, expected list[list[float]].",
                    "",
                )
        except WorkerError as exc:
            raise ProviderError(
                self._worker_error_message("Embedding", exc),
                provider="llama-cpp",
            ) from exc
        except TimeoutError as exc:
            raise ProviderError(
                "Embedding worker timed out. Please try again.",
                provider="llama-cpp",
            ) from exc
        return result

    def rerank(self, query: str, candidates: list[str]) -> list[float]:
        """Score *candidates* by relevance to *query* via the pool worker.

        An empty *candidates* returns ``[]`` without touching the pool.
        """
        if not candidates:
            return []
        accessor = self._get_pool_accessor(
            WorkerRole.RERANK, rerank_worker_main, _make_role_config_factory(WorkerRole.RERANK)
        )
        runtime = self._pool_runtime()
        try:
            result = runtime.run_sync(
                accessor.call(
                    WireKind.RERANK,
                    RerankPayload(query=query, candidates=candidates),
                    timeout=cfg.worker_pool_call_timeout_s,
                ),
                timeout=cfg.worker_pool_call_timeout_s,
            )
            if not isinstance(result, list):
                raise WorkerError(
                    "ProtocolError",
                    f"Pool rerank returned {type(result).__name__}, expected list[float].",
                    "",
                )
        except WorkerError as exc:
            raise ProviderError(
                self._worker_error_message("Rerank", exc),
                provider="llama-cpp",
            ) from exc
        except TimeoutError as exc:
            raise ProviderError(
                "Rerank worker timed out. Please try again.",
                provider="llama-cpp",
            ) from exc
        return result

    def supports_rerank(self) -> bool:
        """llama-cpp can rerank iff llama-cpp-python exposes the rank pooling type."""
        return _llama_cpp_has_rank_pooling()

    def vision_ocr(
        self, png_bytes: bytes, model: str, prompt: str = "", *, timeout: float | None = None
    ) -> str:
        """Run vision OCR on one PNG via the persistent pool worker.

        The wall-clock budget comes from :meth:`_vision_call_budget`.
        """
        accessor = self._get_pool_accessor(
            WorkerRole.VISION, vision_worker_main, _make_role_config_factory(WorkerRole.VISION)
        )
        runtime = self._pool_runtime()
        budget = self._vision_call_budget(timeout)
        request = VisionRequest(png_bytes=png_bytes, prompt=prompt, model=model or None)
        try:
            result = runtime.run_sync(
                accessor.call(WireKind.VISION, request, timeout=budget),
                timeout=budget,
            )
            if not isinstance(result, str):
                raise WorkerError(
                    "ProtocolError",
                    f"Pool vision_ocr returned {type(result).__name__}, expected str.",
                    "",
                )
        except WorkerError as exc:
            raise ProviderError(
                self._worker_error_message("Vision", exc),
                provider="llama-cpp",
            ) from exc
        except TimeoutError as exc:
            raise ProviderError(
                "Vision worker timed out. Please try again.",
                provider="llama-cpp",
            ) from exc
        return result

    @staticmethod
    def _vision_call_budget(timeout: float | None) -> float:
        """Wall-clock budget for one vision_ocr call (per-call > cfg.ocr_timeout > no cap)."""
        effective = timeout if timeout is not None else cfg.ocr_timeout
        return float(effective) if effective and effective > 0 else _VISION_NO_CAP_TIMEOUT_S

    def pdf_ocr(
        self,
        path: Path,
        *,
        backend: OcrBackend,
        model: str = "",
        per_page_timeout_s: float | None = None,
        quiet: bool = True,
        on_progress: Callable[..., None] | None = None,
    ) -> list[PageText]:
        """Run multi-page vision PDF OCR via the persistent vision worker.

        ``per_page_timeout_s`` is *per page*. The total wall-clock cap on the
        streamed drain is ``pages * per_page + cfg.vision_load_budget_s``
        (load grace). A 100-page scan with a 60 s per-page budget therefore
        gets ~6000 s + load, not 60 s for the whole document.

        ``quiet`` is accepted for Protocol parity only. Each streamed page is
        reported to *on_progress* (if given) as an ``EventType.EXTRACT`` event.

        Raises:
            ProviderError: on a worker failure, a timeout, or an unexpected
                frame type in the stream.
        """
        accessor = self._get_pool_accessor(
            WorkerRole.VISION, vision_worker_main, _make_role_config_factory(WorkerRole.VISION)
        )
        runtime = self._pool_runtime()
        budget = self._pdf_drain_budget(path, per_page_timeout_s)
        del quiet  # accepted for Protocol parity; worker has no Rich progress to suppress.
        request = PdfOcrRequest(
            path=str(path),
            backend=backend,
            model=model,
        )

        async def _drain() -> list[PageText]:
            pages: list[PageText] = []
            stream = cast(AsyncIterator[Any], accessor.stream(WireKind.PDF_OCR, request))
            async for frame in stream:
                if not isinstance(frame, PdfOcrChunk):
                    raise ProviderError(
                        f"PDF OCR worker streamed unexpected frame type {type(frame).__name__}.",
                        provider="llama-cpp",
                    )
                pages.append(PageText(frame.page, frame.text))
                if on_progress is not None:
                    on_progress(
                        EventType.EXTRACT,
                        ExtractEvent(file=path.name, page=frame.page, total_pages=frame.total),
                    )
            return pages

        try:
            return runtime.run_sync(_drain(), timeout=budget)
        except WorkerError as exc:
            raise ProviderError(
                self._worker_error_message("PDF OCR", exc),
                provider="llama-cpp",
            ) from exc
        except TimeoutError as exc:
            raise ProviderError(
                "PDF OCR worker timed out. Please try again.",
                provider="llama-cpp",
            ) from exc

    def _pdf_drain_budget(self, path: Path, per_page_timeout_s: float | None) -> float:
        """Total drain timeout = page_count * per_page + vision_load_budget_s.

        Returns the no-cap sentinel when no positive per-page budget is given,
        or when the page count cannot be probed.
        """
        if not per_page_timeout_s or per_page_timeout_s <= 0:
            return _VISION_NO_CAP_TIMEOUT_S
        try:
            pages = pdf_page_count(path)
        except Exception:
            # If we can't probe pages upfront the worker will still try,
            # but we lose the precise budget; fall back to no-cap so the
            # parent doesn't kill a valid run on a probe failure.
            return _VISION_NO_CAP_TIMEOUT_S
        return float(pages) * per_page_timeout_s + cfg.vision_load_budget_s

    @overload
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: Literal[False] = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str: ...

    @overload
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: Literal[True],
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> ClosableIterator[str]: ...

    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: bool = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str | ClosableIterator[str]:
        """Chat completion via the persistent pool worker.

        Streaming returns a :class:`ClosableIterator[str]` whose ``close()``
        flips the worker's abort flag so in-flight generation drains cleanly.
        Non-streaming returns the assembled assistant text.
        """
        accessor = self._get_pool_accessor(
            WorkerRole.CHAT, chat_worker_main, _make_role_config_factory(WorkerRole.CHAT)
        )
        runtime = self._pool_runtime()
        accessor.clear_abort()  # honor mid-stream cancels from the previous turn
        request = ChatRequest(
            messages=messages,
            stream=stream,
            options=self._chat_kwargs_from_options(options) or None,
            model=model,
        )
        if stream:
            return _PoolChatStreamIterator(
                runtime=runtime,
                accessor=accessor,
                async_iter=accessor.stream(WireKind.CHAT, request),
            )
        try:
            result = runtime.run_sync(
                accessor.call(WireKind.CHAT, request, timeout=cfg.worker_pool_call_timeout_s),
                timeout=cfg.worker_pool_call_timeout_s,
            )
            if not isinstance(result, str):
                raise WorkerError(
                    "ProtocolError",
                    f"Pool chat returned {type(result).__name__}, expected str.",
                    "",
                )
        except WorkerError as exc:
            raise ProviderError(
                self._worker_error_message("Chat", exc),
                provider="llama-cpp",
            ) from exc
        except TimeoutError as exc:
            raise ProviderError(
                "Chat worker timed out. Please try again.",
                provider="llama-cpp",
            ) from exc
        return result

    @staticmethod
    def _chat_kwargs_from_options(options: dict[str, Any] | None) -> dict[str, Any]:
        """Translate user-facing options into llama-cpp create_chat_completion kwargs.

        ``num_predict`` is renamed to ``max_tokens``. ``num_ctx`` is dropped
        because it is a model-load parameter, not a per-call one. Missing or
        empty *options* yield ``{}``.
        """
        if not options:
            return {}
        # Copy first so the pops below can never mutate a dict the caller
        # still holds (e.g. if filter_options hands back its input).
        filtered = dict(filter_options(options))
        if "num_predict" in filtered:
            filtered["max_tokens"] = filtered.pop("num_predict")
        filtered.pop("num_ctx", None)  # model-load param, not per-call
        # top_k kept here (local llama.cpp honors it); the API path in
        # translate_options strips it since hosted providers ignore it.
        return filtered

    def list_models(self) -> list[str]:
        """List installed models from the registry, sorted by ref."""
        registry = get_services().registry
        return sorted(m.ref for m in registry.list_installed())

    def list_chat_models(self, provider: str) -> list[str]:
        """llama-cpp has no frontier-provider catalog; always ``[]``."""
        return []

    def pull_model(self, model: str, *, on_progress: Callable[..., Any] | None = None) -> None:
        """Not supported directly: ``lilbee.catalog`` handles downloads."""
        raise NotImplementedError(
            f"llama-cpp provider cannot pull model {model!r}. "
            "Download GGUF files manually or use the catalog."
        )

    def show_model(self, model: str) -> dict[str, Any] | None:
        """Return model metadata from GGUF headers, or None if the model can't be resolved."""
        try:
            path = resolve_model_path(model)
        except ProviderError:
            return None
        return read_gguf_metadata(path)

    def get_capabilities(self, model: str) -> list[str]:
        """Detect capabilities from local GGUF files.

        Rerank models return ``["rerank"]``, because cross-encoder GGUFs
        cannot generate text. Other models report ``"completion"``, plus
        ``"vision"`` when an mmproj sidecar is present.
        """
        if _is_rerank_model(model):
            return ["rerank"]
        caps: list[str] = ["completion"]
        try:
            path = resolve_model_path(model)
        except ProviderError:
            log.debug("resolve_model_path failed for %s", model, exc_info=True)
            return caps
        try:
            find_mmproj_for_model(path)
            caps.append("vision")
        except ProviderError:
            log.debug("no mmproj for %s", model, exc_info=True)
        return caps

    def warm_up_pool(self) -> None:
        """Register roles for every configured model. Idempotent.

        Called by ``Services`` when ``cfg.worker_pool_eager_start`` is on, so
        ``WorkerPool.start_eager()`` has roles to spawn. Roles whose model is
        unset are skipped. A setup with only ``chat_model`` +
        ``embedding_model`` configured therefore eager-starts exactly those
        two and does not pay rerank or vision spawn cost.
        """
        for role in _ROLE_SPECS:
            if not _is_role_configured(role):
                continue
            entrypoint = _ROLE_ENTRYPOINTS[role]
            self._get_pool_accessor(role, entrypoint, _make_role_config_factory(role))

    def shutdown(self) -> None:
        """Drop pool registrations so a follow-up provider can re-register cleanly."""
        self._release_pool_roles()

    def _release_pool_roles(self) -> None:
        """Drop our registrations on the Services pool so the next call respawns.

        The registered-role set is snapshotted and cleared under
        ``self._pool_lock``, so a concurrent ``_get_pool_accessor`` does not
        race the removal. The pool releases run after the lock is dropped.
        When nothing was registered, this returns before touching Services,
        so it is safe even if Services was never built (early shutdown on
        import-time failure).
        """
        with self._pool_lock:
            roles = tuple(self._registered_roles)
            self._registered_roles.clear()
        if not roles:
            return
        from lilbee.providers.worker.pool import PoolShutdownError

        services = get_services()
        runtime = services.pool_runtime
        for role in roles:
            try:
                runtime.run_sync(services.worker_pool.release(role), timeout=10.0)
            except PoolShutdownError:
                # Pool already shut down (atexit ordering during a CLI exit
                # tears down the pool runtime before this provider). Nothing
                # to release; silent no-op.
                pass
            except (TimeoutError, RuntimeError, OSError) as exc:
                log.warning("Pool release of role=%s raised %s", role, exc)

    def invalidate_load_cache(self, model_path: Path | None = None) -> None:
        """Drop the pool's per-role workers so the next call respawns with current settings.

        The ``model_path`` argument is accepted for protocol parity with
        other providers but does not narrow the scope: workers reload all
        their roles on respawn anyway.
        """
        del model_path
        self._release_pool_roles()
class _PoolChatStreamIterator:
    """Sync facade over an async chat-stream iterator from the worker pool.

    Each ``__next__`` submits one ``__anext__`` to the pool's runtime loop and
    blocks for the result. ``close()`` asks the worker to abort via the role
    accessor's ``cancel()``, then drains still-in-flight chunks up to
    ``_CHAT_STREAM_DRAIN_CAP``.
    """

    def __init__(
        self,
        *,
        runtime: PoolRuntime,
        accessor: RoleAccessor,
        async_iter: Any,
    ) -> None:
        self._runtime = runtime
        self._accessor = accessor
        self._async_iter = async_iter
        # Set once the stream has ended, failed, or been closed; after that
        # __next__ raises StopIteration immediately and close() is a no-op.
        self._exhausted = False

    def __iter__(self) -> _PoolChatStreamIterator:
        return self

    def __next__(self) -> str:
        if self._exhausted:
            raise StopIteration
        try:
            chunk: str = self._runtime.run_sync(
                self._async_iter.__anext__(),
                timeout=cfg.worker_pool_call_timeout_s,
            )
            return chunk
        except StopAsyncIteration:
            self._exhausted = True
            raise StopIteration from None
        except WorkerError as exc:
            # Mid-stream worker crashes propagate as ProviderError so the
            # streaming path matches the non-streaming contract.
            self._exhausted = True
            raise ProviderError(
                LlamaCppProvider._worker_error_message("Chat", exc),
                provider="llama-cpp",
            ) from exc
        except TimeoutError as exc:
            self._exhausted = True
            raise ProviderError(
                "Chat worker timed out mid-stream. Please try again.",
                provider="llama-cpp",
            ) from exc

    def close(self) -> None:
        """Cancel mid-stream and drain remaining tokens from the pipe.

        The drain is bounded by ``_CHAT_STREAM_DRAIN_CAP`` so a stuck worker
        cannot block close() indefinitely. Once the cap fires, the partial
        state is accepted rather than hanging the UI. The iterator is marked
        exhausted even if ``cancel()`` or ``clear_abort()`` raises, so
        ``__del__`` does not retry a failed close.
        """
        if self._exhausted:
            return
        try:
            self._accessor.cancel()
            drained = 0
            while drained < _CHAT_STREAM_DRAIN_CAP:
                try:
                    next(self)
                except StopIteration:
                    break
                except Exception:
                    # Best-effort drain: an error here only ends the drain.
                    # Keep a debug trace rather than discarding it silently.
                    log.debug("chat stream drain stopped on error", exc_info=True)
                    break
                drained += 1
            self._accessor.clear_abort()
        finally:
            self._exhausted = True

    def __del__(self) -> None:  # pragma: no cover
        with contextlib.suppress(Exception):
            self.close()
@dataclass(frozen=True)
class _RoleSpec:
    """Immutable recipe describing how to derive one role's :class:`RoleConfig` from cfg."""

    # Name of the ``cfg`` attribute that holds this role's model name.
    cfg_attr: str
    # Loader-mode value copied onto the resulting RoleConfig.
    mode: str
# One spec per worker role. Insertion order is kept as-is because
# ``warm_up_pool`` iterates this mapping when registering roles.
_ROLE_SPECS: dict[WorkerRole, _RoleSpec] = {
    WorkerRole.EMBED: _RoleSpec(cfg_attr="embedding_model", mode=LoaderMode.EMBED),
    WorkerRole.RERANK: _RoleSpec(cfg_attr="reranker_model", mode=LoaderMode.RERANK),
    WorkerRole.CHAT: _RoleSpec(cfg_attr="chat_model", mode=LoaderMode.CHAT),
    # The vision worker does not go through load_llama; it calls
    # load_vision_llama (an mtmd loader) directly. "vision" is therefore
    # only a descriptive mode hint.
    WorkerRole.VISION: _RoleSpec(cfg_attr="vision_model", mode="vision"),
}
# Worker subprocess entry point for each role; used by ``warm_up_pool`` when
# registering configured roles.
_ROLE_ENTRYPOINTS: dict[WorkerRole, Callable[..., None]] = {
    WorkerRole.EMBED: embed_worker_main,
    WorkerRole.RERANK: rerank_worker_main,
    WorkerRole.CHAT: chat_worker_main,
    WorkerRole.VISION: vision_worker_main,
}
def _is_role_configured(role: WorkerRole) -> bool:
    """Report whether cfg currently names a (non-empty) model for *role*."""
    attr_name = _ROLE_SPECS[role].cfg_attr
    configured_model = getattr(cfg, attr_name)
    return bool(configured_model)
def _make_role_config_factory(role: WorkerRole) -> Callable[[], RoleConfig]:
    """Build a zero-arg factory that resolves *role*'s configured model when called.

    The pool invokes the factory on every spawn (lazy start or restart).
    Because the cfg attribute is read at call time, model changes in cfg take
    effect without an explicit invalidation.

    The returned factory raises ProviderError when the role's cfg attribute
    is empty, or when the model name cannot be resolved to a file.
    """
    spec = _ROLE_SPECS[role]

    def _make() -> RoleConfig:
        model_name = getattr(cfg, spec.cfg_attr)
        if model_name:
            return RoleConfig(
                role=role,
                model_path=resolve_model_path(model_name),
                mode=spec.mode,
            )
        raise ProviderError(
            f"No {role} model configured. Set cfg.{spec.cfg_attr} first.",
            provider="llama-cpp",
        )

    return _make
def resolve_model_path(model: str) -> Path:
    """Resolve a model name to a ``.gguf`` file path.

    Resolution order:

    1. The model registry (canonical source for installed models). A
       ``KeyError`` or ``ValueError`` from ``registry.resolve`` falls through
       to the next step.
    2. An absolute filesystem path, accepted only if it names an existing
       regular file.

    Raises:
        ProviderError: if *model* is an absolute path that is not an existing
            file, or if it is neither a registry entry nor an absolute path.
    """
    registry = get_services().registry
    try:
        return registry.resolve(model)
    except (KeyError, ValueError):
        pass

    candidate = Path(model)
    if candidate.is_absolute():
        # is_file() rather than exists(): a directory would satisfy exists()
        # and only fail later, inside the native model loader.
        if candidate.is_file():
            return candidate
        raise ProviderError(f"Model file not found: {model}", provider="llama-cpp")

    raise ProviderError(
        f"Model {model!r} not found in registry. "
        "Install it via the catalog or 'lilbee model pull'.",
        provider="llama-cpp",
    )
def _llama_cpp_has_rank_pooling() -> bool:
    """Return True iff the installed llama-cpp-python exposes ``LLAMA_POOLING_TYPE_RANK``."""
    # This can run before any model has been loaded (feature detection for
    # status / catalog UIs). Going through import_llama_cpp() first surfaces
    # the libvulkan hint instead of a raw OSError on bare Linux. Only
    # ImportError (llama_cpp genuinely absent, or the constant missing) maps
    # to "no rerank support". A libvulkan-flavoured failure is a real install
    # error and is left to propagate as ProviderError.
    try:
        import_llama_cpp()
        from llama_cpp import LLAMA_POOLING_TYPE_RANK  # noqa: F401
    except ImportError:
        return False
    else:
        return True
def load_llama(
    model_path: Path,
    *,
    mode: LoaderMode,
    abort_callback_override: Any = None,
) -> Any:
    """Load a ``llama_cpp.Llama`` in chat, embed, or rerank mode.

    Chat loads take ``n_ctx`` from ``cfg.num_ctx`` when set, otherwise from
    ``_resolve_chat_ctx``. They also apply the flash-attention and
    KV-cache-type settings.

    Embed and rerank loads set ``embedding=True`` and size ``n_ctx`` (and
    ``n_batch`` / ``n_ubatch``) from the GGUF training context. They are
    constructed under the ``n_seq_max`` shim and verified afterwards. Rerank
    additionally uses rank pooling and must be a single-class model.

    ``abort_callback_override`` lets pool workers bind a callback that reads
    the worker's shared ``mp.Value`` abort flag.
    """
    llama_cls = import_llama_cpp().Llama
    install_llama_log_handler()

    is_embedding = mode in (LoaderMode.EMBED, LoaderMode.RERANK)
    kwargs: dict[str, Any] = {
        "model_path": str(model_path),
        "embedding": is_embedding,
        "verbose": False,
        "n_gpu_layers": _resolve_n_gpu_layers(embedding=is_embedding),
    }
    if cfg.main_gpu is not None:
        kwargs["main_gpu"] = cfg.main_gpu

    if not is_embedding:
        if cfg.num_ctx is not None:
            kwargs["n_ctx"] = cfg.num_ctx
        else:
            meta = _safe_read_gguf_metadata(model_path)
            kwargs["n_ctx"] = _resolve_chat_ctx(model_path, meta)
            log.info(
                "Chat n_ctx=%d for %s (dynamic, training_ctx=%s)",
                kwargs["n_ctx"],
                model_path.name,
                (meta or {}).get("context_length", "unknown"),
            )
        _apply_flash_attention(kwargs)
        _apply_kv_cache_type(kwargs)
        if abort_callback_override is not None:
            kwargs["abort_callback"] = abort_callback_override
        return _construct_llama(llama_cls, model_path, kwargs)

    # Embedding / rerank path. The model's training context is used
    # unconditionally: cfg.num_ctx is chat-tuned, and propagating it here used
    # to clamp the rerank model below what a query+candidate pair needs
    # ("llama_decode returned 1" on alternate queries with a small chat ctx
    # on low-RAM hosts). An explicit value (not 0 = "model default") keeps
    # the OOM-retry path working, since ``_halve_ctx_for_retry`` cannot
    # bisect from 0.
    embed_meta = _safe_read_gguf_metadata(model_path)
    train_ctx = train_ctx_from_meta(
        embed_meta, fallback=_EMBED_FALLBACK_CTX, model_path=model_path
    )
    kwargs["n_ctx"] = train_ctx
    # llama-cpp-python defaults n_batch = min(n_ctx, 512), which silently
    # truncates embeddings to 512 tokens; match the batch sizes to n_ctx so
    # each text can use the full context window.
    kwargs["n_batch"] = train_ctx
    kwargs["n_ubatch"] = train_ctx
    if mode == LoaderMode.RERANK:
        from llama_cpp import LLAMA_POOLING_TYPE_RANK

        kwargs["pooling_type"] = LLAMA_POOLING_TYPE_RANK
    if abort_callback_override is not None:
        kwargs["abort_callback"] = abort_callback_override

    from lilbee.providers.llama_cpp.batching import EMBED_N_SEQ_MAX

    with _llama_n_seq_max(EMBED_N_SEQ_MAX):
        llm = _construct_llama(llama_cls, model_path, kwargs)
    _verify_embed_n_seq_max(llm, EMBED_N_SEQ_MAX)
    if mode == LoaderMode.RERANK:
        _verify_single_class_reranker(llm)
    return llm
def _safe_read_gguf_metadata(model_path: Path) -> dict[str, str] | None:
    """Read GGUF header metadata, returning None (and debug-logging) on any failure."""
    metadata: dict[str, str] | None = None
    try:
        metadata = read_gguf_metadata(model_path)
    except Exception:
        log.debug("read_gguf_metadata failed for %s", model_path, exc_info=True)
    return metadata
# Context length assumed for an embedding GGUF whose ``context_length``
# metadata is zero, negative, or unparseable. Some nomic-embed and Qwen3
# GGUFs in circulation report ``0`` (the b473 QA dump logged
# ``n_ctx_seq (512) > n_ctx_train (0)``). 2048 is the documented training
# context of the smallest featured embedder that needs the fallback,
# Google's EmbeddingGemma-300m
# (https://huggingface.co/google/embeddinggemma-300m). llama.cpp only warns
# when n_ctx exceeds n_ctx_train, so the larger nomic embedder still loads
# cleanly with the same value.
_EMBED_FALLBACK_CTX = 2048
def _resolve_chat_ctx(model_path: Path, meta: dict[str, str] | None) -> int:
    """Pick n_ctx aiming for ``cfg.chat_n_ctx_target``, clamped to model + host.

    When ``cfg.num_ctx_max`` is ``None``, the model's training_ctx is the only
    ceiling. A long-context model can then grow past the target if the host
    has the RAM to back it. Setting ``num_ctx_max`` explicitly caps below
    training_ctx for per-host policy reasons.

    If the dynamic estimate raises ``OSError`` or ``ValueError``, the result
    falls back to the smallest of training_ctx, the ceiling, and the target.
    An explicit ``num_ctx_max`` is therefore honoured on that path too.
    """
    training_ctx = train_ctx_from_meta(meta, fallback=DEFAULT_NUM_CTX, model_path=model_path)
    ceiling = cfg.num_ctx_max if cfg.num_ctx_max is not None else training_ctx

    try:
        model_bytes = model_path.stat().st_size
        available = get_available_memory(cfg.gpu_memory_fraction)
        kv_per_tok = kv_bytes_per_token(meta, _kv_elem_bytes_for_cfg())
        return compute_dynamic_ctx(
            model_bytes=model_bytes,
            available_bytes=available,
            training_ctx=training_ctx,
            kv_bytes_per_tok=kv_per_tok,
            ceiling=ceiling,
            target=cfg.chat_n_ctx_target,
        )
    except (OSError, ValueError):
        log.debug("dynamic ctx sizing failed for %s, using static cap", model_path, exc_info=True)
        # Previously min(training_ctx, target), which ignored num_ctx_max and
        # could exceed the user's explicit cap on this fallback path.
        return min(training_ctx, ceiling, cfg.chat_n_ctx_target)
def _kv_elem_bytes_for_cfg() -> int:
    """Look up the per-element KV-cache size, in bytes, for ``cfg.kv_cache_type``."""
    cache_type = cfg.kv_cache_type
    return KV_CACHE_TYPE_BYTES[cache_type]
def _resolve_n_gpu_layers(*, embedding: bool) -> int:
    """Translate ``cfg.n_gpu_layers`` into llama-cpp's offload integer.

    Embedding/rerank loads always offload everything. For chat, ``None`` in
    cfg also means "all layers"; any other value is passed through.
    """
    if embedding:
        return _N_GPU_LAYERS_AUTO
    configured_layers = cfg.n_gpu_layers
    return _N_GPU_LAYERS_AUTO if configured_layers is None else configured_layers
def _apply_flash_attention(kwargs: dict[str, Any]) -> None:
    """Request flash attention in *kwargs* unless ``cfg.flash_attention`` is ``False``.

    ``None`` (auto) and ``True`` both set ``flash_attn=True``. ``False`` leaves
    *kwargs* untouched. Builds of llama-cpp-python that reject the kwarg are
    handled later: the construct loop drops ``flash_attn`` on ``TypeError``.
    """
    if cfg.flash_attention is not False:
        kwargs["flash_attn"] = True
def _apply_kv_cache_type(kwargs: dict[str, Any]) -> None:
    """Write llama-cpp-python ``type_k``/``type_v`` for ``cfg.kv_cache_type``.

    *kwargs* is left untouched in three cases:

    * the cache type is F16 (presumably llama.cpp's own default; confirm);
    * the GGML type constants cannot be resolved from llama_cpp;
    * the configured type has no resolved constant.

    Otherwise the same GGML type is used for both the K and V caches.
    """
    if cfg.kv_cache_type is KvCacheType.F16:
        return

    ggml_types = _ggml_type_map()
    if ggml_types is None:
        log.debug("llama_cpp internal types unavailable; skipping KV quant")
        return

    resolved = ggml_types.get(cfg.kv_cache_type)
    # A None here is defensive against enum values added without a mapping.
    if resolved is not None:
        kwargs.update(type_k=resolved, type_v=resolved)
937def _ggml_type_map() -> dict[KvCacheType, Any] | None:
938 """Resolve llama-cpp-python's GGML_TYPE_* constants, or None on older builds."""
939 try:
940 from llama_cpp import llama_cpp as _llc
941 except Exception: # pragma: no cover -- only fires on llama-cpp-python without _llc
942 return None
943 return {
944 KvCacheType.F32: getattr(_llc, "GGML_TYPE_F32", None),
945 KvCacheType.F16: getattr(_llc, "GGML_TYPE_F16", None),
946 KvCacheType.Q8_0: getattr(_llc, "GGML_TYPE_Q8_0", None),
947 KvCacheType.Q4_0: getattr(_llc, "GGML_TYPE_Q4_0", None),
948 }
def _construct_llama(llama_cls: Any, model_path: Path, kwargs: dict[str, Any]) -> Any:
    """Call ``llama_cls(**kwargs)`` with FA fallback and OOM-retry-with-halved-ctx.

    *kwargs* is mutated in place: ``abort_callback`` is defaulted,
    ``flash_attn`` may be dropped, and ``n_ctx`` (plus batch sizes) may shrink.

    Two independent retry budgets apply:

    * a ``TypeError`` that names ``flash_attn`` drops that kwarg and retries,
      at most once;
    * a ``ValueError`` that looks like a llama.cpp memory failure halves
      ``n_ctx`` and retries, at most ``_MAX_OOM_RETRIES`` times.

    Because the budgets are separate, the flash-attention fallback never
    consumes an OOM retry. Even with ``_MAX_OOM_RETRIES == 0`` an FA rejection
    still gets its retry. Each iteration either returns, raises, or advances a
    bounded counter, so the loop always terminates.

    Raises:
        TypeError: Any ``TypeError`` the FA fallback does not handle.
        ValueError: The (possibly wrapped) load error when it is not a memory
            failure, the OOM budget is exhausted, or ``n_ctx`` cannot shrink.
    """
    # Fresh abort flag per load: a prior request_abort() that interrupted
    # an inference must not latch and abort the next model swap.
    clear_abort()
    kwargs.setdefault("abort_callback", abort_callback)
    fa_dropped = False
    oom_retries = 0
    while True:
        try:
            return suppress_native_stderr(llama_cls, **kwargs)
        except TypeError as exc:
            if not _drop_flash_attn_if_unsupported(exc, kwargs, fa_dropped):
                raise
            fa_dropped = True
        except ValueError as exc:
            # _raise_load_error always raises; it never returns normally.
            if oom_retries >= _MAX_OOM_RETRIES or not _is_load_oom(exc):
                _raise_load_error(model_path, kwargs, exc)
            if not _halve_ctx_for_retry(kwargs, exc):
                _raise_load_error(model_path, kwargs, exc)
            oom_retries += 1
979def _drop_flash_attn_if_unsupported(
980 exc: TypeError, kwargs: dict[str, Any], already_dropped: bool
981) -> bool:
982 """If the TypeError is about an unsupported ``flash_attn`` kwarg, drop it."""
983 if already_dropped or "flash_attn" not in kwargs or "flash_attn" not in str(exc):
984 return False
985 log.info("llama-cpp-python rejected flash_attn=True; retrying without it")
986 kwargs.pop("flash_attn", None)
987 return True
990def _halve_ctx_for_retry(kwargs: dict[str, Any], exc: ValueError) -> bool:
991 """Halve n_ctx (and matching batch sizes) for an OOM retry. Returns False if no progress."""
992 current_ctx = int(kwargs.get("n_ctx", 0) or 0)
993 if current_ctx <= 0:
994 return False
995 new_ctx = max(_CTX_FLOOR, (current_ctx // 2 // _CTX_QUANTUM) * _CTX_QUANTUM)
996 if new_ctx >= current_ctx:
997 return False
998 log.warning(
999 "llama.cpp load failed at n_ctx=%d (%s); retrying at n_ctx=%d",
1000 current_ctx,
1001 str(exc).splitlines()[0],
1002 new_ctx,
1003 )
1004 kwargs["n_ctx"] = new_ctx
1005 for key in ("n_batch", "n_ubatch"):
1006 if key in kwargs:
1007 kwargs[key] = new_ctx
1008 return True
def _raise_load_error(model_path: Path, kwargs: dict[str, Any], exc: ValueError) -> None:
    """Raise a diagnostic for a llama.cpp load failure; never returns normally.

    If ``_wrap_llama_load_error`` produces a diagnostic, that is raised with
    *exc* chained as its cause. Otherwise *exc* itself is re-raised unchanged.
    """
    diagnostic = _wrap_llama_load_error(model_path, kwargs, exc)
    if diagnostic is not None:
        raise diagnostic from exc
    raise exc
1019def _is_load_oom(exc: ValueError) -> bool:
1020 """Does this ValueError look like a llama.cpp memory failure?"""
1021 err = str(exc)
1022 return "llama_context" in err or "load model from file" in err
def _wrap_llama_load_error(
    model_path: Path, kwargs: dict[str, Any], exc: ValueError
) -> ValueError | None:
    """Build a diagnostic ``ValueError`` for an opaque llama.cpp load failure.

    Only failures recognised by ``_is_load_oom`` are wrapped. Anything else
    returns ``None`` so the caller re-raises the original exception.

    The message includes:

    * the model file name and size;
    * the ``n_ctx`` that was attempted;
    * the host's available RAM, when ``psutil`` works;
    * remediation hints;
    * the raw llama.cpp error text.

    Args:
        model_path: GGUF file that failed to load (it may be missing on disk).
        kwargs: Constructor kwargs used for the failed attempt.
        exc: The ``ValueError`` raised while constructing the model.

    Returns:
        A new ``ValueError`` carrying the diagnostic message, or ``None``.
    """
    # Share the marker check with the retry loop so the two cannot drift apart.
    if not _is_load_oom(exc):
        return None
    err = str(exc)
    # EAFP: a missing or unreadable file reports 0.0 GB instead of breaking
    # the diagnostic (FileNotFoundError is an OSError subclass).
    try:
        size_gb = model_path.stat().st_size / (1024**3)
    except OSError:
        size_gb = 0.0
    n_ctx_label = kwargs.get("n_ctx", 0) or "model default"
    parts = [
        f"Failed to load {model_path.name} ({size_gb:.1f} GB) with n_ctx={n_ctx_label}.",
    ]
    try:
        import psutil

        free_gb = psutil.virtual_memory().available / (1024**3)
        parts.append(f"Host has {free_gb:.1f} GB free RAM.")
    except Exception as psu_exc:  # pragma: no cover
        log.debug("psutil unavailable: %s", psu_exc)
    parts.append(
        "Try a smaller model, lower LILBEE_NUM_CTX, set LILBEE_KV_CACHE_TYPE=q8_0, "
        "or close other processes to free RAM. "
        f"(llama.cpp: {err})"
    )
    return ValueError(" ".join(parts))
1056def _is_rerank_model(model: str) -> bool:
1057 """Check if *model* is an exact rerank catalog entry by ref or hf_repo."""
1058 if not model:
1059 return False
1060 return is_rerank_ref(model)