Coverage for src/lilbee/providers/llama_cpp/provider.py: 100%
444 statements
coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
1"""Llama.cpp provider: class, model loader, and path resolution."""
3from __future__ import annotations
5import contextlib
6import logging
7import threading
8from collections.abc import AsyncIterator, Callable
9from dataclasses import dataclass
10from pathlib import Path
11from typing import Any, Literal, cast, overload
13from lilbee.app.services import get_services
14from lilbee.catalog import is_rerank_ref
15from lilbee.core.config import DEFAULT_NUM_CTX, cfg
16from lilbee.core.config.enums import KV_CACHE_TYPE_BYTES, KvCacheType
17from lilbee.providers.base import ClosableIterator, LLMProvider, ProviderError, filter_options
18from lilbee.providers.llama_cpp.abort_signal import abort_callback, clear_abort
19from lilbee.providers.llama_cpp.gguf_meta import (
20 find_mmproj_for_model,
21 read_gguf_metadata,
22 train_ctx_from_meta,
23)
24from lilbee.providers.llama_cpp.log_dispatch import (
25 import_llama_cpp,
26 install_llama_log_handler,
27 suppress_native_stderr,
28)
29from lilbee.providers.model_cache import (
30 LoaderMode,
31 compute_dynamic_ctx,
32 get_available_memory,
33 kv_bytes_per_token,
34)
35from lilbee.providers.worker.chat_worker import chat_worker_main
36from lilbee.providers.worker.embed_worker import embed_worker_main
37from lilbee.providers.worker.pool import PoolRuntime, RoleAccessor
38from lilbee.providers.worker.rerank_worker import rerank_worker_main
39from lilbee.providers.worker.transport import (
40 ChatRequest,
41 OcrBackend,
42 PdfOcrRequest,
43 RerankPayload,
44 RoleConfig,
45 VisionRequest,
46 WorkerRole,
47)
48from lilbee.providers.worker.transport_pipe import WorkerCrashError, WorkerError
49from lilbee.providers.worker.vision_worker import vision_worker_main
50from lilbee.providers.worker.wire_kinds import WireKind
51from lilbee.runtime.progress import EventType, ExtractEvent
52from lilbee.vision import PageText, PdfOcrChunk, pdf_page_count
54log = logging.getLogger(__name__)
56# Vision OCR sentinel used when neither a per-call timeout nor ``cfg.ocr_timeout``
57# is set. 24 h is effectively "no cap" for the round-trip wait loop.
58_VISION_NO_CAP_TIMEOUT_S = 86_400.0
60_LLAMA_CONTEXT_PATCH_LOCK = threading.Lock()
61"""Serialises overlapping ``_llama_n_seq_max`` callers inside one process.
63The shim mutates ``llama_cpp.internals.LlamaContext.__init__`` globally
64while the with-block is open. Worker subprocesses each load one model
65serially today, but the lock keeps the contract safe if a future caller
66loads two models concurrently.
67"""
70@contextlib.contextmanager
71def _llama_n_seq_max(n_seq_max: int) -> Any:
72 """Set ``context_params.n_seq_max`` on the next ``LlamaContext`` constructed.
74 Workaround for llama-cpp-python upstream issue #2051 (``n_seq_max``
75 not exposed as a Llama kwarg). See ``docs/architecture.md`` for the
76 full rationale and the upstream-fix removal hint.
77 """
78 from llama_cpp import internals
80 with _LLAMA_CONTEXT_PATCH_LOCK:
81 original = internals.LlamaContext.__init__
83 def patched(self: Any, *, model: Any, params: Any, verbose: bool) -> None:
84 params.n_seq_max = n_seq_max
85 original(self, model=model, params=params, verbose=verbose)
87 internals.LlamaContext.__init__ = patched # type: ignore[method-assign,assignment]
88 try:
89 yield
90 finally:
91 internals.LlamaContext.__init__ = original # type: ignore[method-assign]
94# Cap on tokens drained during ``_PoolChatStreamIterator.close()`` after a
95# mid-stream cancel. A runaway model (Qwen3-0.6B stuck in a never-closing
96# ``<think>`` loop) would otherwise block close() indefinitely.
97_CHAT_STREAM_DRAIN_CAP = 1024
99# Chat-load OOM retry knobs. The OOM wrapper halves ``n_ctx`` (rounded down to
100# the next ``_CTX_QUANTUM`` multiple) up to ``_MAX_OOM_RETRIES`` times before
101# raising. ``_CTX_FLOOR`` is the smallest ``n_ctx`` we'll attempt.
102_MAX_OOM_RETRIES = 2
103_CTX_QUANTUM = 256
104_CTX_FLOOR = 512
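# Worked example of the retry schedule above (illustrative only; it follows the
# halving rule in _halve_ctx_for_retry below): a load that OOMs at n_ctx=8192
# retries at 4096, then 2048, then raises. An odd starting value such as 3000
# rounds down via (3000 // 2 // 256) * 256 = 1280. From 1024 the first retry
# lands on the 512 floor; a further halving would give max(512, 256) == 512,
# no progress, so the next OOM raises immediately.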
106# Sentinel passed to ``llama-cpp-python`` for "offload all layers".
107_N_GPU_LAYERS_AUTO = -1
110# Settings baked into Llama() at load time, or whose change picks a
111# different model file. Sampling params are read per-call and excluded.
112LOAD_AFFECTING_KEYS = frozenset(
113 {
114 "num_ctx",
115 "chat_model",
116 "embedding_model",
117 "vision_model",
118 "reranker_model",
119 }
120)
122# Subset of LOAD_AFFECTING_KEYS whose change is observed by the worker on the
123# next per-call ``request.model`` (chat_worker / vision_worker check the path
124# in ``_ensure_loaded`` and reload in place). For these, the parent does not
125# need to release the pool role; the next call swaps the model inside the live
126# worker, saving the 1-3 s spawn cost.
127PER_CALL_RELOADABLE_KEYS = frozenset({"chat_model", "vision_model"})
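# Illustrative consequence of the two sets above (a sketch, not an exhaustive
# matrix): a cfg.chat_model or cfg.vision_model change is picked up by the live
# worker on its next request via the per-call model check, while a change to
# cfg.num_ctx, cfg.embedding_model, or cfg.reranker_model only lands after the
# role is released (for example via invalidate_load_cache()), so the next call
# respawns the worker with the new settings.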
130class LlamaCppProvider(LLMProvider):
131 """Provider backed by llama-cpp-python for local GGUF model inference."""
133 def __init__(self) -> None:
134 self._pool_lock = threading.Lock()
135 self._registered_roles: set[WorkerRole] = set()
137 @staticmethod
138 def _worker_error_message(role_label: str, exc: WorkerError) -> str:
139 """Render a user-facing message that names the role and points at the log.
141 ``WorkerCrashError`` already embeds the log path in its message; for
142 plain ``WorkerError`` (the worker reported an exception or returned
143 a malformed reply) the surfaced text is the worker's exception
144 repr so the user sees enough to file a bug report.
145 """
146 detail = str(exc)
147 if detail.endswith("."): # drop the trailing period so the templates below don't produce ".."
148 detail = detail[:-1]
149 if isinstance(exc, WorkerCrashError):
150 return f"{role_label} worker exited unexpectedly. {detail}. Please try again."
151 return f"{role_label} worker reported an error: {detail}. Please try again."
153 def _pool_runtime(self) -> PoolRuntime:
154 """Return the Services-owned :class:`PoolRuntime`, starting it lazily."""
155 runtime = get_services().pool_runtime
156 runtime.start()
157 return runtime
159 def _get_pool_accessor(
160 self,
161 role: WorkerRole,
162 worker_main: Any,
163 config_factory: Callable[[], RoleConfig],
164 ) -> RoleAccessor:
165 """Register *role* on the Services pool the first time it is used.
167 Subsequent calls return the same accessor without touching the
168 pool state. Registration is gated by ``self._pool_lock`` so two
169 concurrent first-callers do not race to register the role twice.
170 """
171 pool = get_services().worker_pool
172 with self._pool_lock:
173 if role not in self._registered_roles:
174 pool.register(role, worker_main, config_factory)
175 self._registered_roles.add(role)
176 return pool.accessor(role)
178 def embed(self, texts: list[str]) -> list[list[float]]:
179 """Embed texts via the persistent pool worker.
181 Worker crashes and timeouts surface as :class:`ProviderError`;
182 the pool respawns the embed role lazily on the next call.
183 """
184 accessor = self._get_pool_accessor(
185 WorkerRole.EMBED, embed_worker_main, _make_role_config_factory(WorkerRole.EMBED)
186 )
187 runtime = self._pool_runtime()
188 try:
189 result = runtime.run_sync(
190 accessor.call(WireKind.EMBED, texts, timeout=cfg.worker_pool_call_timeout_s),
191 timeout=cfg.worker_pool_call_timeout_s,
192 )
193 if not isinstance(result, list):
194 raise WorkerError(
195 "ProtocolError",
196 f"Pool embed returned {type(result).__name__}, expected list[list[float]].",
197 "",
198 )
199 except WorkerError as exc:
200 raise ProviderError(
201 self._worker_error_message("Embedding", exc),
202 provider="llama-cpp",
203 ) from exc
204 except TimeoutError as exc:
205 raise ProviderError(
206 "Embedding worker timed out. Please try again.",
207 provider="llama-cpp",
208 ) from exc
209 return result
211 def rerank(self, query: str, candidates: list[str]) -> list[float]:
212 """Score *candidates* by relevance to *query* via the pool worker."""
213 if not candidates:
214 return []
215 accessor = self._get_pool_accessor(
216 WorkerRole.RERANK, rerank_worker_main, _make_role_config_factory(WorkerRole.RERANK)
217 )
218 runtime = self._pool_runtime()
219 try:
220 result = runtime.run_sync(
221 accessor.call(
222 WireKind.RERANK,
223 RerankPayload(query=query, candidates=candidates),
224 timeout=cfg.worker_pool_call_timeout_s,
225 ),
226 timeout=cfg.worker_pool_call_timeout_s,
227 )
228 if not isinstance(result, list):
229 raise WorkerError(
230 "ProtocolError",
231 f"Pool rerank returned {type(result).__name__}, expected list[float].",
232 "",
233 )
234 except WorkerError as exc:
235 raise ProviderError(
236 self._worker_error_message("Rerank", exc),
237 provider="llama-cpp",
238 ) from exc
239 except TimeoutError as exc:
240 raise ProviderError(
241 "Rerank worker timed out. Please try again.",
242 provider="llama-cpp",
243 ) from exc
244 return result
246 def supports_rerank(self) -> bool:
247 """llama-cpp can rerank iff llama-cpp-python exposes the rank pooling type."""
248 return _llama_cpp_has_rank_pooling()
250 def vision_ocr(
251 self, png_bytes: bytes, model: str, prompt: str = "", *, timeout: float | None = None
252 ) -> str:
253 """Run vision OCR via the persistent pool worker."""
254 accessor = self._get_pool_accessor(
255 WorkerRole.VISION, vision_worker_main, _make_role_config_factory(WorkerRole.VISION)
256 )
257 runtime = self._pool_runtime()
258 budget = self._vision_call_budget(timeout)
259 request = VisionRequest(png_bytes=png_bytes, prompt=prompt, model=model or None)
260 try:
261 result = runtime.run_sync(
262 accessor.call(WireKind.VISION, request, timeout=budget),
263 timeout=budget,
264 )
265 if not isinstance(result, str):
266 raise WorkerError(
267 "ProtocolError",
268 f"Pool vision_ocr returned {type(result).__name__}, expected str.",
269 "",
270 )
271 except WorkerError as exc:
272 raise ProviderError(
273 self._worker_error_message("Vision", exc),
274 provider="llama-cpp",
275 ) from exc
276 except TimeoutError as exc:
277 raise ProviderError(
278 "Vision worker timed out. Please try again.",
279 provider="llama-cpp",
280 ) from exc
281 return result
283 @staticmethod
284 def _vision_call_budget(timeout: float | None) -> float:
285 """Wall-clock budget for one vision_ocr call (per-call > cfg.ocr_timeout > no cap)."""
286 effective = timeout if timeout is not None else cfg.ocr_timeout
287 return float(effective) if effective and effective > 0 else _VISION_NO_CAP_TIMEOUT_S
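# Budget examples for the precedence above (illustrative only):
#   timeout=30.0                        -> 30 s, regardless of cfg.ocr_timeout
#   timeout=None, cfg.ocr_timeout=120   -> 120 s
#   timeout=None, cfg.ocr_timeout unset (or a non-positive effective value)
#                                       -> the 24 h _VISION_NO_CAP_TIMEOUT_S sentinel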
289 def pdf_ocr(
290 self,
291 path: Path,
292 *,
293 backend: OcrBackend,
294 model: str = "",
295 per_page_timeout_s: float | None = None,
296 quiet: bool = True,
297 on_progress: Callable[..., None] | None = None,
298 ) -> list[PageText]:
299 """Run multi-page vision PDF OCR via the persistent vision worker.
301 ``per_page_timeout_s`` is *per page*. The total wall-clock cap on
302 the streamed drain is ``pages * per_page + cfg.vision_load_budget_s``
303 (load grace), so a 100-page scan with a 60 s per-page budget gets
304 ~6000 s + load, not 60 s for the whole document.
305 """
306 accessor = self._get_pool_accessor(
307 WorkerRole.VISION, vision_worker_main, _make_role_config_factory(WorkerRole.VISION)
308 )
309 runtime = self._pool_runtime()
310 budget = self._pdf_drain_budget(path, per_page_timeout_s)
311 del quiet # accepted for Protocol parity; worker has no Rich progress to suppress.
312 request = PdfOcrRequest(
313 path=str(path),
314 backend=backend,
315 model=model,
316 )
317 progress = on_progress
319 async def _drain() -> list[PageText]:
320 pages: list[PageText] = []
321 stream = cast(AsyncIterator[Any], accessor.stream(WireKind.PDF_OCR, request))
322 async for frame in stream:
323 if not isinstance(frame, PdfOcrChunk):
324 raise ProviderError(
325 f"PDF OCR worker streamed unexpected frame type {type(frame).__name__}.",
326 provider="llama-cpp",
327 )
328 pages.append(PageText(frame.page, frame.text))
329 if progress is not None:
330 progress(
331 EventType.EXTRACT,
332 ExtractEvent(file=path.name, page=frame.page, total_pages=frame.total),
333 )
334 return pages
336 try:
337 return runtime.run_sync(_drain(), timeout=budget)
338 except WorkerError as exc:
339 raise ProviderError(
340 self._worker_error_message("PDF OCR", exc),
341 provider="llama-cpp",
342 ) from exc
343 except TimeoutError as exc:
344 raise ProviderError(
345 "PDF OCR worker timed out. Please try again.",
346 provider="llama-cpp",
347 ) from exc
349 def _pdf_drain_budget(self, path: Path, per_page_timeout_s: float | None) -> float:
350 """Total drain timeout = page_count * per_page + vision_load_budget_s."""
351 if not per_page_timeout_s or per_page_timeout_s <= 0:
352 return _VISION_NO_CAP_TIMEOUT_S
353 try:
354 pages = pdf_page_count(path)
355 except Exception:
356 # If we can't probe pages upfront the worker will still try,
357 # but we lose the precise budget; fall back to no-cap so the
358 # parent doesn't kill a valid run on a probe failure.
359 return _VISION_NO_CAP_TIMEOUT_S
360 return float(pages) * per_page_timeout_s + cfg.vision_load_budget_s
362 @overload
363 def chat(
364 self,
365 messages: list[dict[str, str]],
366 *,
367 stream: Literal[False] = False,
368 options: dict[str, Any] | None = None,
369 model: str | None = None,
370 ) -> str: ...
372 @overload
373 def chat(
374 self,
375 messages: list[dict[str, str]],
376 *,
377 stream: Literal[True],
378 options: dict[str, Any] | None = None,
379 model: str | None = None,
380 ) -> ClosableIterator[str]: ...
382 def chat(
383 self,
384 messages: list[dict[str, str]],
385 *,
386 stream: bool = False,
387 options: dict[str, Any] | None = None,
388 model: str | None = None,
389 ) -> str | ClosableIterator[str]:
390 """Chat completion via the persistent pool worker.
392 Streaming returns a :class:`ClosableIterator[str]` whose
393 ``close()`` flips the worker's abort flag so in-flight generation
394 drains cleanly. Non-streaming returns the assembled assistant text.
395 """
396 accessor = self._get_pool_accessor(
397 WorkerRole.CHAT, chat_worker_main, _make_role_config_factory(WorkerRole.CHAT)
398 )
399 runtime = self._pool_runtime()
400 accessor.clear_abort() # reset any abort latched by a mid-stream cancel on the previous turn
401 request = ChatRequest(
402 messages=messages,
403 stream=stream,
404 options=self._chat_kwargs_from_options(options) or None,
405 model=model,
406 )
407 if stream:
408 return _PoolChatStreamIterator(
409 runtime=runtime,
410 accessor=accessor,
411 async_iter=accessor.stream(WireKind.CHAT, request),
412 )
413 try:
414 result = runtime.run_sync(
415 accessor.call(WireKind.CHAT, request, timeout=cfg.worker_pool_call_timeout_s),
416 timeout=cfg.worker_pool_call_timeout_s,
417 )
418 if not isinstance(result, str):
419 raise WorkerError(
420 "ProtocolError",
421 f"Pool chat returned {type(result).__name__}, expected str.",
422 "",
423 )
424 except WorkerError as exc:
425 raise ProviderError(
426 self._worker_error_message("Chat", exc),
427 provider="llama-cpp",
428 ) from exc
429 except TimeoutError as exc:
430 raise ProviderError(
431 "Chat worker timed out. Please try again.",
432 provider="llama-cpp",
433 ) from exc
434 return result
436 @staticmethod
437 def _chat_kwargs_from_options(options: dict[str, Any] | None) -> dict[str, Any]:
438 """Translate user-facing options into llama-cpp create_chat_completion kwargs."""
439 if not options:
440 return {}
441 filtered = filter_options(options)
442 if "num_predict" in filtered:
443 filtered["max_tokens"] = filtered.pop("num_predict")
444 filtered.pop("num_ctx", None) # model-load param, not per-call
445 return filtered
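# Sketch of the translation above, with hypothetical option values (assumes
# filter_options() passes num_predict and temperature through unchanged):
#   {"num_predict": 256, "temperature": 0.2, "num_ctx": 4096}
#   -> {"max_tokens": 256, "temperature": 0.2}  # num_ctx dropped (load-time only)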
447 def list_models(self) -> list[str]:
448 """List installed models from registry."""
449 registry = get_services().registry
450 return sorted(m.ref for m in registry.list_installed())
452 def list_chat_models(self, provider: str) -> list[str]:
453 """llama-cpp has no frontier-provider catalog; always ``[]``."""
454 return []
456 def pull_model(self, model: str, *, on_progress: Callable[..., Any] | None = None) -> None:
457 """Not supported directly: ``lilbee.catalog`` handles downloads."""
458 raise NotImplementedError(
459 f"llama-cpp provider cannot pull model {model!r}. "
460 "Download GGUF files manually or use the catalog."
461 )
463 def show_model(self, model: str) -> dict[str, Any] | None:
464 """Return model metadata from GGUF headers."""
465 try:
466 path = resolve_model_path(model)
467 except ProviderError:
468 return None
469 return read_gguf_metadata(path)
471 def get_capabilities(self, model: str) -> list[str]:
472 """Detect capabilities from local GGUF files.
474 Rerank models return ``["rerank"]``; cross-encoder GGUFs cannot
475 generate text. Other models report ``"completion"``, plus
476 ``"vision"`` when an mmproj sidecar is present.
477 """
478 if _is_rerank_model(model):
479 return ["rerank"]
480 caps: list[str] = ["completion"]
481 try:
482 path = resolve_model_path(model)
483 except ProviderError:
484 log.debug("resolve_model_path failed for %s", model, exc_info=True)
485 return caps
486 try:
487 find_mmproj_for_model(path)
488 caps.append("vision")
489 except ProviderError:
490 log.debug("no mmproj for %s", model, exc_info=True)
491 return caps
493 def warm_up_pool(self) -> None:
494 """Register roles for every configured model. Idempotent.
496 Called by ``Services`` when ``cfg.worker_pool_eager_start`` is on so
497 ``WorkerPool.start_eager()`` has roles to spawn. Roles whose model is
498 unset are skipped; this lets a setup with only ``chat_model`` +
499 ``embedding_model`` configured eager-start exactly those two without
500 paying the rerank or vision spawn cost.
501 """
502 for role, _spec in _ROLE_SPECS.items():
503 if not _is_role_configured(role):
504 continue
505 entrypoint = _ROLE_ENTRYPOINTS[role]
506 self._get_pool_accessor(role, entrypoint, _make_role_config_factory(role))
508 def shutdown(self) -> None:
509 """Drop pool registrations so a follow-up provider can re-register cleanly."""
510 self._release_pool_roles()
512 def _release_pool_roles(self) -> None:
513 """Drop our registrations on the Services pool so the next call respawns.
515 Safe even when Services has not yet been built (early shutdown
516 on import-time failure). Holds ``self._pool_lock`` so a concurrent
517 ``_get_pool_accessor`` does not race the role removal.
518 """
519 with self._pool_lock:
520 roles = tuple(self._registered_roles)
521 self._registered_roles.clear()
522 if not roles:
523 return
524 from lilbee.providers.worker.pool import PoolShutdownError
526 services = get_services()
527 runtime = services.pool_runtime
528 for role in roles:
529 try:
530 runtime.run_sync(services.worker_pool.release(role), timeout=10.0)
531 except PoolShutdownError:
532 # Pool already shut down (atexit ordering during a CLI exit
533 # tears down the pool runtime before this provider). Nothing
534 # to release; silent no-op.
535 pass
536 except (TimeoutError, RuntimeError, OSError) as exc:
537 log.warning("Pool release of role=%s raised %s", role, exc)
539 def invalidate_load_cache(self, model_path: Path | None = None) -> None:
540 """Drop the pool's per-role workers so the next call respawns with current settings.
542 The ``model_path`` argument is accepted for protocol parity with
543 other providers but does not narrow the scope: workers reload all
544 their roles on respawn anyway.
545 """
546 del model_path
547 self._release_pool_roles()
550class _PoolChatStreamIterator:
551 """Sync facade over an async chat-stream iterator from the worker pool.
553 Each ``__next__`` submits one ``__anext__`` to the pool's runtime
554 loop and blocks for the result. ``close()`` flips the worker's abort
555 flag so any in-flight generation stops at the next token-tick;
556 in-flight chunks already in the pipe still drain.
557 """
559 def __init__(
560 self,
561 *,
562 runtime: PoolRuntime,
563 accessor: RoleAccessor,
564 async_iter: Any,
565 ) -> None:
566 self._runtime = runtime
567 self._accessor = accessor
568 self._async_iter = async_iter
569 self._exhausted = False
571 def __iter__(self) -> _PoolChatStreamIterator:
572 return self
574 def __next__(self) -> str:
575 if self._exhausted:
576 raise StopIteration
577 try:
578 chunk: str = self._runtime.run_sync(
579 self._async_iter.__anext__(),
580 timeout=cfg.worker_pool_call_timeout_s,
581 )
582 return chunk
583 except StopAsyncIteration:
584 self._exhausted = True
585 raise StopIteration from None
586 except WorkerError as exc:
587 # Mid-stream worker crashes propagate as ProviderError so the
588 # streaming path matches the non-streaming contract.
589 self._exhausted = True
590 raise ProviderError(
591 LlamaCppProvider._worker_error_message("Chat", exc),
592 provider="llama-cpp",
593 ) from exc
594 except TimeoutError as exc:
595 self._exhausted = True
596 raise ProviderError(
597 "Chat worker timed out mid-stream. Please try again.",
598 provider="llama-cpp",
599 ) from exc
601 def close(self) -> None:
602 """Cancel mid-stream and drain remaining tokens from the pipe.
604 Drain is bounded by ``_CHAT_STREAM_DRAIN_CAP`` so a stuck
605 worker cannot block close() indefinitely; once the cap fires we
606 accept a partial drain rather than hang the UI.
607 """
608 if self._exhausted:
609 return
610 self._accessor.cancel()
611 drained = 0
612 while drained < _CHAT_STREAM_DRAIN_CAP:
613 try:
614 next(self)
615 except StopIteration:
616 break
617 except Exception:
618 break
619 drained += 1
620 self._accessor.clear_abort()
621 self._exhausted = True
623 def __del__(self) -> None: # pragma: no cover
624 with contextlib.suppress(Exception):
625 self.close()
628@dataclass(frozen=True)
629class _RoleSpec:
630 """Per-role recipe for building a :class:`RoleConfig` from cfg."""
632 cfg_attr: str
633 mode: str
636_ROLE_SPECS: dict[WorkerRole, _RoleSpec] = {
637 WorkerRole.EMBED: _RoleSpec(cfg_attr="embedding_model", mode=LoaderMode.EMBED),
638 WorkerRole.RERANK: _RoleSpec(cfg_attr="reranker_model", mode=LoaderMode.RERANK),
639 WorkerRole.CHAT: _RoleSpec(cfg_attr="chat_model", mode=LoaderMode.CHAT),
640 # Vision uses a custom mtmd loader (not load_llama); the mode hint is
641 # documentation only; the vision worker calls load_vision_llama directly.
642 WorkerRole.VISION: _RoleSpec(cfg_attr="vision_model", mode="vision"),
643}
646_ROLE_ENTRYPOINTS: dict[WorkerRole, Callable[..., None]] = {
647 WorkerRole.EMBED: embed_worker_main,
648 WorkerRole.RERANK: rerank_worker_main,
649 WorkerRole.CHAT: chat_worker_main,
650 WorkerRole.VISION: vision_worker_main,
651}
654def _is_role_configured(role: WorkerRole) -> bool:
655 """True iff the cfg attribute for *role* holds a non-empty model name."""
656 return bool(getattr(cfg, _ROLE_SPECS[role].cfg_attr))
659def _make_role_config_factory(role: WorkerRole) -> Callable[[], RoleConfig]:
660 """Return a factory that resolves the role's configured model at spawn time.
662 The pool calls the factory on every spawn (lazy or restart) so model
663 swaps in cfg propagate without an explicit invalidation call.
664 """
665 spec = _ROLE_SPECS[role]
667 def _make() -> RoleConfig:
668 model_name = getattr(cfg, spec.cfg_attr)
669 if not model_name:
670 raise ProviderError(
671 f"No {role} model configured. Set cfg.{spec.cfg_attr} first.",
672 provider="llama-cpp",
673 )
674 return RoleConfig(
675 role=role,
676 model_path=resolve_model_path(model_name),
677 mode=spec.mode,
678 )
680 return _make
683def resolve_model_path(model: str) -> Path:
684 """Resolve a model name to a .gguf file path.
685 Resolution order:
686 1. Registry (canonical source for installed models)
687 2. Absolute path (if it points to an existing file)
688 """
689 registry = get_services().registry
690 try:
691 return registry.resolve(model)
692 except (KeyError, ValueError):
693 pass
695 # Absolute path to a .gguf file
696 candidate = Path(model)
697 if candidate.is_absolute():
698 if candidate.exists():
699 return candidate
700 raise ProviderError(f"Model file not found: {model}", provider="llama-cpp")
702 raise ProviderError(
703 f"Model {model!r} not found in registry. "
704 f"Install it via the catalog or 'lilbee model pull'.",
705 provider="llama-cpp",
706 )
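# Resolution sketch (hypothetical inputs): resolve_model_path("qwen3-4b")
# returns whatever path the registry maps that ref to;
# resolve_model_path("/models/foo.gguf") returns that path if the file exists
# and raises "Model file not found" if it does not; an unregistered relative
# name falls through to the final ProviderError pointing at the catalog /
# 'lilbee model pull'.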
709def _llama_cpp_has_rank_pooling() -> bool:
710 """Return True iff the installed llama-cpp-python exposes ``LLAMA_POOLING_TYPE_RANK``."""
711 # supports_rerank() can be called before any model load (feature detection
712 # for status / catalog UIs), so route through import_llama_cpp() first to
713 # surface the libvulkan hint rather than a raw OSError on bare Linux. A
714 # genuinely-missing llama_cpp package surfaces as ImportError and means
715 # "no rerank support"; a libvulkan-flavored OSError is a real install
716 # error and must propagate as ProviderError to the caller.
717 try:
718 import_llama_cpp()
719 from llama_cpp import LLAMA_POOLING_TYPE_RANK # noqa: F401
720 except ImportError:
721 return False
722 return True
725def load_llama(
726 model_path: Path,
727 *,
728 mode: LoaderMode,
729 abort_callback_override: Any = None,
730) -> Any:
731 """Load a llama_cpp.Llama in chat, embed, or rerank mode.
733 ``abort_callback_override`` lets pool workers bind a callback that
734 reads the worker's shared ``mp.Value`` abort flag.
735 """
736 Llama = import_llama_cpp().Llama # noqa: N806
738 install_llama_log_handler()
739 embedding = mode in (LoaderMode.EMBED, LoaderMode.RERANK)
740 kwargs: dict[str, Any] = {
741 "model_path": str(model_path),
742 "embedding": embedding,
743 "verbose": False,
744 "n_gpu_layers": _resolve_n_gpu_layers(embedding=embedding),
745 }
746 if cfg.main_gpu is not None:
747 kwargs["main_gpu"] = cfg.main_gpu
749 if embedding:
750 # Embedding/rerank uses the model's training context unconditionally.
751 # cfg.num_ctx is a chat-tuned setting; propagating it here used to
752 # clamp the rerank model below what a query+candidate pair needs and
753 # produced "llama_decode returned 1" on every other query when the
754 # user picked a small chat ctx for a low-RAM box. The explicit
755 # ``embed_train_ctx`` value (instead of ``0`` for "use model
756 # default") keeps the OOM-retry path working: ``_halve_ctx_for_retry``
757 # cannot bisect from 0.
758 embed_meta = _safe_read_gguf_metadata(model_path)
759 embed_train_ctx = train_ctx_from_meta(
760 embed_meta, fallback=_EMBED_FALLBACK_CTX, model_path=model_path
761 )
762 kwargs["n_ctx"] = embed_train_ctx
763 elif cfg.num_ctx is not None:
764 kwargs["n_ctx"] = cfg.num_ctx
765 else:
766 meta = _safe_read_gguf_metadata(model_path)
767 kwargs["n_ctx"] = _resolve_chat_ctx(model_path, meta)
768 log.info(
769 "Chat n_ctx=%d for %s (dynamic, training_ctx=%s)",
770 kwargs["n_ctx"],
771 model_path.name,
772 (meta or {}).get("context_length", "unknown"),
773 )
775 if embedding:
776 # llama-cpp-python defaults n_batch = min(n_ctx, 512), silently
777 # truncating embeddings to 512 tokens. Set n_batch = n_ctx so each
778 # text can use the model's full context window.
779 kwargs["n_batch"] = kwargs["n_ctx"]
780 kwargs["n_ubatch"] = kwargs["n_ctx"]
782 if mode == LoaderMode.RERANK:
783 from llama_cpp import LLAMA_POOLING_TYPE_RANK
785 kwargs["pooling_type"] = LLAMA_POOLING_TYPE_RANK
787 if not embedding:
788 _apply_flash_attention(kwargs)
789 _apply_kv_cache_type(kwargs)
791 if abort_callback_override is not None:
792 kwargs["abort_callback"] = abort_callback_override
794 if embedding:
795 from lilbee.providers.llama_cpp.batching import EMBED_N_SEQ_MAX
797 with _llama_n_seq_max(EMBED_N_SEQ_MAX):
798 return _construct_llama(Llama, model_path, kwargs)
799 return _construct_llama(Llama, model_path, kwargs)
802def _safe_read_gguf_metadata(model_path: Path) -> dict[str, str] | None:
803 """Best-effort GGUF metadata read, returning None on any failure."""
804 try:
805 return read_gguf_metadata(model_path)
806 except Exception:
807 log.debug("read_gguf_metadata failed for %s", model_path, exc_info=True)
808 return None
811# Fallback used when an embedding GGUF reports zero, negative, or
812# unparseable ``context_length`` in its metadata header. Some published
813# nomic-embed and Qwen3 GGUFs in the wild report ``0`` (the b473 QA dump
814# logged ``n_ctx_seq (512) > n_ctx_train (0)``). 2048 is the documented
815# training-context for the smallest featured embedder that uses it
816# (Google's EmbeddingGemma-300m, see
817# https://huggingface.co/google/embeddinggemma-300m), and llama.cpp
818# tolerates n_ctx > n_ctx_train with a warning, so the larger nomic
819# embedder still loads cleanly under the same fallback.
820_EMBED_FALLBACK_CTX = 2048
823def _resolve_chat_ctx(model_path: Path, meta: dict[str, str] | None) -> int:
824 """Pick the largest 256-multiple n_ctx that fits in available memory."""
825 training_ctx = train_ctx_from_meta(meta, fallback=DEFAULT_NUM_CTX, model_path=model_path)
826 ceiling = cfg.num_ctx_max
828 try:
829 model_bytes = model_path.stat().st_size
830 available = get_available_memory(cfg.gpu_memory_fraction)
831 kv_per_tok = kv_bytes_per_token(meta, _kv_elem_bytes_for_cfg())
832 return compute_dynamic_ctx(
833 model_bytes=model_bytes,
834 available_bytes=available,
835 training_ctx=training_ctx,
836 kv_bytes_per_tok=kv_per_tok,
837 ceiling=ceiling,
838 )
839 except (OSError, ValueError):
840 log.debug("dynamic ctx sizing failed for %s, using static cap", model_path, exc_info=True)
841 return min(training_ctx, DEFAULT_NUM_CTX)
844def _kv_elem_bytes_for_cfg() -> int:
845 """Bytes per KV element implied by the configured cache type."""
846 return KV_CACHE_TYPE_BYTES[cfg.kv_cache_type]
849def _resolve_n_gpu_layers(*, embedding: bool) -> int:
850 """Resolve ``cfg.n_gpu_layers`` (None=all) to llama-cpp's offload integer."""
851 if embedding or cfg.n_gpu_layers is None:
852 return _N_GPU_LAYERS_AUTO
853 return cfg.n_gpu_layers
856def _apply_flash_attention(kwargs: dict[str, Any]) -> None:
857 """Set ``flash_attn`` per ``cfg.flash_attention`` (None=auto, True/False=force)."""
858 if cfg.flash_attention is False:
859 return
860 # None (auto) and True both pass flash_attn=True; the construct loop
861 # drops it on TypeError if llama-cpp-python doesn't support it.
862 kwargs["flash_attn"] = True
865def _apply_kv_cache_type(kwargs: dict[str, Any]) -> None:
866 """Map ``cfg.kv_cache_type`` to llama-cpp-python ``type_k`` / ``type_v``."""
867 if cfg.kv_cache_type is KvCacheType.F16:
868 return
869 type_map = _ggml_type_map()
870 if type_map is None:
871 log.debug("llama_cpp internal types unavailable; skipping KV quant")
872 return
873 ggml_type = type_map.get(cfg.kv_cache_type)
874 if ggml_type is None: # pragma: no cover -- defensive against new enum values
875 return
876 kwargs["type_k"] = ggml_type
877 kwargs["type_v"] = ggml_type
880def _ggml_type_map() -> dict[KvCacheType, Any] | None:
881 """Resolve llama-cpp-python's GGML_TYPE_* constants, or None on older builds."""
882 try:
883 from llama_cpp import llama_cpp as _llc
884 except Exception: # pragma: no cover -- only fires on llama-cpp-python without _llc
885 return None
886 return {
887 KvCacheType.F32: getattr(_llc, "GGML_TYPE_F32", None),
888 KvCacheType.F16: getattr(_llc, "GGML_TYPE_F16", None),
889 KvCacheType.Q8_0: getattr(_llc, "GGML_TYPE_Q8_0", None),
890 KvCacheType.Q4_0: getattr(_llc, "GGML_TYPE_Q4_0", None),
891 }
894def _construct_llama(llama_cls: Any, model_path: Path, kwargs: dict[str, Any]) -> Any:
895 """Call ``llama_cls(**kwargs)`` with FA fallback and OOM-retry-with-halved-ctx.
897 Each loop iteration either returns the loaded model, raises (failure
898 or unrelated TypeError), or continues with halved n_ctx; the loop is
899 therefore structurally exhaustive and never falls through.
900 """
901 # Fresh abort flag per load: a prior request_abort() that interrupted
902 # an inference must not latch and abort the next model swap.
903 clear_abort()
904 kwargs.setdefault("abort_callback", abort_callback)
905 fa_dropped = False
906 for attempt in range(_MAX_OOM_RETRIES + 1):
907 try:
908 return suppress_native_stderr(llama_cls, **kwargs)
909 except TypeError as exc:
910 if not _drop_flash_attn_if_unsupported(exc, kwargs, fa_dropped):
911 raise
912 fa_dropped = True
913 continue
914 except ValueError as exc:
915 if attempt == _MAX_OOM_RETRIES or not _is_load_oom(exc):
916 _raise_load_error(model_path, kwargs, exc)
917 if not _halve_ctx_for_retry(kwargs, exc):
918 _raise_load_error(model_path, kwargs, exc)
919 raise RuntimeError("unreachable: _construct_llama loop fell through") # pragma: no cover
922def _drop_flash_attn_if_unsupported(
923 exc: TypeError, kwargs: dict[str, Any], already_dropped: bool
924) -> bool:
925 """If the TypeError is about an unsupported ``flash_attn`` kwarg, drop it."""
926 if already_dropped or "flash_attn" not in kwargs or "flash_attn" not in str(exc):
927 return False
928 log.info("llama-cpp-python rejected flash_attn=True; retrying without it")
929 kwargs.pop("flash_attn", None)
930 return True
933def _halve_ctx_for_retry(kwargs: dict[str, Any], exc: ValueError) -> bool:
934 """Halve n_ctx (and matching batch sizes) for an OOM retry. Returns False if no progress."""
935 current_ctx = int(kwargs.get("n_ctx", 0) or 0)
936 if current_ctx <= 0:
937 return False
938 new_ctx = max(_CTX_FLOOR, (current_ctx // 2 // _CTX_QUANTUM) * _CTX_QUANTUM)
939 if new_ctx >= current_ctx:
940 return False
941 log.warning(
942 "llama.cpp load failed at n_ctx=%d (%s); retrying at n_ctx=%d",
943 current_ctx,
944 str(exc).splitlines()[0],
945 new_ctx,
946 )
947 kwargs["n_ctx"] = new_ctx
948 for key in ("n_batch", "n_ubatch"):
949 if key in kwargs:
950 kwargs[key] = new_ctx
951 return True
954def _raise_load_error(model_path: Path, kwargs: dict[str, Any], exc: ValueError) -> None:
955 """Raise the wrapped diagnostic for a llama.cpp load failure, or re-raise as-is."""
956 wrapped = _wrap_llama_load_error(model_path, kwargs, exc)
957 if wrapped is None:
958 raise exc
959 raise wrapped from exc
962def _is_load_oom(exc: ValueError) -> bool:
963 """Does this ValueError look like a llama.cpp memory failure?"""
964 err = str(exc)
965 return "llama_context" in err or "load model from file" in err
968def _wrap_llama_load_error(
969 model_path: Path, kwargs: dict[str, Any], exc: ValueError
970) -> ValueError | None:
971 """Diagnostic ValueError for opaque llama.cpp load failures, or None to pass through."""
972 err = str(exc)
973 if "llama_context" not in err and "load model from file" not in err:
974 return None
975 try:
976 size_gb = model_path.stat().st_size / (1024**3) if model_path.exists() else 0.0
977 except OSError: # pragma: no cover
978 size_gb = 0.0
979 n_ctx = kwargs.get("n_ctx", 0)
980 n_ctx_label = n_ctx or "model default"
981 parts = [
982 f"Failed to load {model_path.name} ({size_gb:.1f} GB) with n_ctx={n_ctx_label}.",
983 ]
984 try:
985 import psutil
987 free_gb = psutil.virtual_memory().available / (1024**3)
988 parts.append(f"Host has {free_gb:.1f} GB free RAM.")
989 except Exception as psu_exc: # pragma: no cover
990 log.debug("psutil unavailable: %s", psu_exc)
991 parts.append(
992 "Try a smaller model, lower LILBEE_NUM_CTX, set LILBEE_KV_CACHE_TYPE=q8_0, "
993 "or close other processes to free RAM. "
994 f"(llama.cpp: {err})"
995 )
996 return ValueError(" ".join(parts))
999def _is_rerank_model(model: str) -> bool:
1000 """Check if *model* is an exact rerank catalog entry by ref or hf_repo."""
1001 if not model:
1002 return False
1003 return is_rerank_ref(model)