Coverage for src / lilbee / providers / base.py: 100%
34 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
1"""Base protocol and exceptions for LLM providers."""
3from __future__ import annotations
5from collections.abc import Callable, Iterator
6from enum import StrEnum
7from pathlib import Path
8from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, overload, runtime_checkable
10from pydantic import BaseModel
12if TYPE_CHECKING:
13 from lilbee.providers.worker.transport import OcrBackend
14 from lilbee.vision import PageText
# Covariant because the protocol only produces values (via ``__next__``) and
# never accepts them as arguments.
T_co = TypeVar("T_co", covariant=True)


@runtime_checkable
class ClosableIterator(Iterator[T_co], Protocol[T_co]):
    """Iterator protocol extended with an explicit ``close()`` hook.

    Streaming chat replies are typed as this so that the upstream model lock
    is guaranteed to be released even when the consumer stops reading before
    the stream is exhausted. Generator objects already expose ``close()`` and
    therefore match structurally; hand-written wrappers (such as the
    llama-cpp chat-lock iterator) define the method themselves.

    Being ``@runtime_checkable``, the protocol also supports ``isinstance``
    checks, which only test for the presence of the required methods.
    """

    def close(self) -> None:
        """Free whatever the iterator holds; may be called before exhaustion."""
class LLMOptions(BaseModel):
    """Whitelist of generation options that may be forwarded to a provider.

    Only the fields declared below survive validation. No ``model_config`` is
    set, so Pydantic's default ``extra`` handling silently drops any other key
    (for example ``api_base`` or ``api_key``) instead of passing it through.
    This keeps callers from injecting sensitive parameters into a provider
    call. Values for the declared fields are type-validated by Pydantic.
    """

    temperature: float | None = None
    top_p: float | None = None
    top_k: int | None = None
    seed: int | None = None
    num_predict: int | None = None
    repeat_penalty: float | None = None
    num_ctx: int | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return every option whose value is not ``None``, as a plain dict."""
        return self.model_dump(exclude_none=True)
def filter_options(options: dict[str, Any]) -> dict[str, Any]:
    """Sanitize caller-supplied generation options.

    The mapping is validated through :class:`LLMOptions`: keys that are not
    declared fields there are dropped, declared fields are type-validated,
    and options whose value is ``None`` are omitted from the result.

    Raises:
        pydantic.ValidationError: if a recognised option has an invalid value.
    """
    validated = LLMOptions(**options)
    return validated.to_dict()
56class ProviderErrorKind(StrEnum):
57 """Provider-agnostic category of a failed provider call.
59 Classified by exception type at each backend boundary so callers can
60 branch on the kind instead of matching message strings (which are
61 provider-specific and drift between SDK versions).
62 """
64 AUTH = "auth"
65 RATE_LIMIT = "rate_limit"
66 CONTEXT_OVERFLOW = "context_overflow"
67 NOT_FOUND = "not_found"
68 BAD_REQUEST = "bad_request"
69 CONNECTION = "connection"
70 SERVER = "server"
71 UNKNOWN = "unknown"
74class ProviderError(Exception):
75 """Raised when an LLM provider operation fails.
77 ``kind`` is the provider-agnostic category; backends that can't classify a
78 failure leave it ``UNKNOWN``.
79 """
81 def __init__(
82 self,
83 message: str,
84 *,
85 provider: str = "",
86 kind: ProviderErrorKind = ProviderErrorKind.UNKNOWN,
87 ) -> None:
88 self.provider = provider
89 self.kind = kind
90 super().__init__(message)
# One chat message, mapping string keys to string values. The exact key set
# (presumably "role"/"content") is not fixed here; TODO confirm against callers.
ChatMessage = dict[str, str]


class LLMProvider(Protocol):
    """Structural interface that every pluggable LLM backend is expected to satisfy.

    Members whose body is ``...`` are pure declarations. ``supports_rerank``,
    ``invalidate_load_cache`` and ``warm_up_pool`` carry real default bodies,
    which a backend inherits only if it subclasses this protocol explicitly.
    """

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed the batch *texts* and return a list of vectors (presumably one per text)."""
        ...

    @overload
    def chat(
        self,
        messages: list[ChatMessage],
        *,
        stream: Literal[False] = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str: ...

    @overload
    def chat(
        self,
        messages: list[ChatMessage],
        *,
        stream: Literal[True],
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> ClosableIterator[str]: ...

    def chat(
        self,
        messages: list[ChatMessage],
        *,
        stream: bool = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str | ClosableIterator[str]:
        """Run a chat completion over *messages*.

        With ``stream=False`` the whole reply is returned as one ``str``.
        With ``stream=True`` the reply is a :class:`ClosableIterator` of
        ``str`` pieces; callers that stop reading early should ``close()`` it.
        """
        ...

    def vision_ocr(
        self,
        png_bytes: bytes,
        model: str,
        prompt: str = "",
        *,
        timeout: float | None = None,
    ) -> str:
        """Run OCR on a single page image supplied as PNG bytes.

        *timeout* is a limit in seconds; ``None`` or ``0`` means no limit.
        """
        ...

    def pdf_ocr(
        self,
        path: Path,
        *,
        backend: OcrBackend,
        model: str = "",
        per_page_timeout_s: float | None = None,
        quiet: bool = True,
        on_progress: Callable[..., None] | None = None,
    ) -> list[PageText]:
        """OCR each page of the PDF at *path*; results follow page order.

        A backend that cannot OCR scanned PDFs locally raises
        :class:`NotImplementedError`, which ingest callers catch and log.
        """
        ...

    def list_models(self) -> list[str]:
        """Return the identifiers of the models this backend can use."""
        ...

    def list_chat_models(self, provider: str) -> list[str]:
        """Return the frontier chat models this backend knows of for *provider*.

        The upstream catalog is returned unfiltered: whatever litellm exposes
        for API providers, or an empty list for backends such as native
        llama-cpp that have no concept of an external catalog.
        """
        ...

    def pull_model(self, model: str, *, on_progress: Callable[..., Any] | None = None) -> None:
        """Download *model*; raise NotImplementedError when downloads are unsupported."""
        ...

    def show_model(self, model: str) -> dict[str, Any] | None:
        """Return metadata for *model*, or None when the backend does not expose any."""
        ...

    def get_capabilities(self, model: str) -> list[str]:
        """Return capability tags for *model*, e.g. ``["completion", "vision"]``.

        The list is empty when the backend cannot report capabilities or
        does not know the model.
        """
        ...

    def rerank(self, query: str, candidates: list[str]) -> list[float]:
        """Score every entry of *candidates* for relevance to *query*.

        The reranker model comes from ``cfg.reranker_model``. Callers MUST
        ensure that value is non-empty before calling; UI-render decisions
        should use :meth:`supports_rerank` instead.

        Returns: one float per candidate, in input order, where a higher
        score means more relevant. An empty ``candidates`` yields ``[]``.
        Raises :class:`ProviderError` if the backend cannot rerank or
        ``cfg.reranker_model`` is empty.
        """
        ...

    def supports_rerank(self) -> bool:
        """Report whether this backend could rerank *if* a model were configured.

        This is a capability probe, not a "reranker is active" flag. For a
        rerank-capable backend, an empty ``cfg.reranker_model`` still yields
        ``True`` so the settings UI keeps the picker visible. Callers that
        need to know whether reranking is actually configured must check
        ``bool(cfg.reranker_model)`` themselves; only ``rerank()`` enforces a
        non-empty value. The default here answers ``False`` (no support).
        """
        return False

    def shutdown(self) -> None:
        """Release held resources such as background threads; a no-op when there are none."""
        ...

    def invalidate_load_cache(self, model_path: Path | None = None) -> None:
        """Forget loaded-model state: all of it for ``None``, else just *model_path*. Default: no-op."""
        return

    def warm_up_pool(self) -> None:
        """Eagerly register configured roles so :meth:`WorkerPool.start_eager` has work to do.

        The default does nothing, so providers without a worker pool (SDK or
        routing wrappers) can be handed to ``Services`` as-is.
        :class:`LlamaCppProvider` overrides it to register the chat, embed,
        rerank and vision roles whose model is configured.
        """
        return