Coverage for src/lilbee/providers/sdk_llm_provider.py: 100%
149 statements
1"""SDK-agnostic LLM provider implementing the public ``LLMProvider`` Protocol.
3``SdkLLMProvider`` owns the semantic layer: auth key injection, option
4translation, model-ref parsing, error wrapping, and lazy one-shot
5backend initialization (``configure_logging`` + ``inject_provider_keys``
6on first use). It speaks to the underlying SDK exclusively through an
7``LlmSdkBackend``, so swapping SDKs is a one-file adapter change.
9Zero direct SDK imports live here. The adapter owns SDK-specific
10concerns like wire-format prefixes (``ollama/``) and OpenAI content-parts
11schema for image inputs.
12"""

from __future__ import annotations

import logging
import os
from collections.abc import Callable
from pathlib import Path
from typing import Any, Literal, overload

from lilbee.core.config import cfg
from lilbee.providers.base import ClosableIterator, LLMProvider, ProviderError
from lilbee.providers.model_ref import parse_model_ref, translate_options
from lilbee.providers.sdk_backend import (
    PROVIDER_KEYS,
    CompletionRequest,
    EmbeddingRequest,
    LlmSdkBackend,
    RerankRequest,
)
from lilbee.providers.worker.transport import OcrBackend
from lilbee.vision import PageText

log = logging.getLogger(__name__)


def inject_provider_keys() -> None:
    """Copy per-provider API keys from config into ``os.environ``.

    OpenAI-compatible SDKs read provider-specific env vars
    (``OPENAI_API_KEY``, ``ANTHROPIC_API_KEY``, ...) at call time. This
    bridges lilbee's config system to that convention. Explicit env
    vars are never overwritten so users can still override via their
    shell.
    """
    for _, cfg_field, env_var, _ in PROVIDER_KEYS:
        value = getattr(cfg, cfg_field, "")
        if value and not os.environ.get(env_var):
            os.environ[env_var] = value
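
# ``inject_provider_keys`` only consumes the second and third fields of each
# ``PROVIDER_KEYS`` entry. Shape sketch (field names and the elided slots are
# assumptions for illustration, not the actual definition):
#
#     PROVIDER_KEYS = [
#         ("openai", "openai_api_key", "OPENAI_API_KEY", ...),
#         ("anthropic", "anthropic_api_key", "ANTHROPIC_API_KEY", ...),
#     ]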


class SdkLLMProvider(LLMProvider):
    """Provider that delegates SDK calls to an ``LlmSdkBackend``."""

    def __init__(
        self,
        backend: LlmSdkBackend,
        *,
        base_url: str = "http://localhost:11434",
        api_key: str = "",
    ) -> None:
        self._backend = backend
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key
        self._initialized = False

    def _ensure_initialized(self) -> None:
        """Apply one-shot backend setup before the first call.

        Runs ``configure_logging(suppress_debug=cfg.json_mode)`` and
        ``inject_provider_keys()`` exactly once, regardless of whether
        the first operation is ``chat``, ``embed``, or a catalog query.
        Both steps happen together because the backend's first SDK
        import must see (a) the debug flag applied, and (b) per-provider
        API keys in ``os.environ``.
        """
        if self._initialized:
            return
        try:
            self._backend.configure_logging(suppress_debug=cfg.json_mode)
        except (ImportError, AttributeError):
            log.debug("backend.configure_logging failed", exc_info=True)
        inject_provider_keys()
        self._initialized = True

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed texts via the configured backend."""
        self._ensure_initialized()
        ref = parse_model_ref(cfg.embedding_model)
        request = EmbeddingRequest(
            ref=ref,
            inputs=texts,
            api_base=self._base_url if ref.needs_api_base else None,
            api_key=self._api_key or None,
        )
        try:
            result = self._backend.embed(request)
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Embedding failed: {exc}", provider=self._backend.provider_name
            ) from exc
        return result.vectors
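
    # Error-wrapping convention used by every delegating method in this class:
    # backend ``ProviderError``s pass through untouched, anything else is
    # wrapped, so callers handle exactly one exception type. Caller-side
    # sketch:
    #
    #     try:
    #         vectors = provider.embed(["alpha", "beta"])
    #     except ProviderError as err:
    #         log.warning("embedding unavailable: %s", err)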

    @overload
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: Literal[False] = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str: ...

    @overload
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: Literal[True],
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> ClosableIterator[str]: ...

    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: bool = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str | ClosableIterator[str]:
        """Chat completion via the configured backend."""
        self._ensure_initialized()
        ref = parse_model_ref(model or cfg.chat_model)
        translated = translate_options(options, ref) if options else {}
        request = CompletionRequest(
            ref=ref,
            messages=list(messages),
            options=translated,
            api_base=self._base_url if ref.needs_api_base else None,
            api_key=self._api_key or None,
        )
        if stream:
            return self._chat_stream(request)
        try:
            result = self._backend.complete(request)
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Chat failed: {exc}", provider=self._backend.provider_name
            ) from exc
        return result.content
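
    # The ``Literal`` overloads above give call sites precise return types:
    #
    #     text = provider.chat(msgs)                 # -> str
    #     stream = provider.chat(msgs, stream=True)  # -> ClosableIterator[str]
    #     for token in stream:
    #         print(token, end="", flush=True)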

    def _chat_stream(self, request: CompletionRequest) -> ClosableIterator[str]:
        """Yield content tokens from a streaming completion.

        Exceptions surfaced by the backend at either call time or during
        iteration are re-raised as ``ProviderError`` so callers always
        see a consistent error type.
        """
        try:
            stream = self._backend.complete_stream(request)
            for chunk in stream:
                if chunk.content:
                    yield chunk.content
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Chat failed: {exc}", provider=self._backend.provider_name
            ) from exc
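
    # Note: ``_chat_stream`` is a generator, so the ``complete_stream`` call
    # and the error wrapping both run lazily, on the caller's first ``next()``,
    # not when ``chat(..., stream=True)`` returns.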

    def vision_ocr(
        self,
        png_bytes: bytes,
        model: str,
        prompt: str = "",
        *,
        timeout: float | None = None,
    ) -> str:
        """OCR via a multipart chat completion; ``timeout`` enforced via thread pool."""
        from lilbee.vision import OCR_PROMPT, build_vision_messages

        messages = build_vision_messages(prompt or OCR_PROMPT, png_bytes)
        if timeout and timeout > 0:
            from concurrent.futures import ThreadPoolExecutor

            with ThreadPoolExecutor(max_workers=1) as pool:
                future = pool.submit(self.chat, messages, stream=False, model=model)
                result = future.result(timeout=timeout)
        else:
            result = self.chat(messages, stream=False, model=model)
        if not isinstance(result, str):
            raise ProviderError(
                f"Vision OCR returned non-text response ({type(result).__name__}).",
                provider=self._backend.provider_name,
            )
        return result
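
    # Caveat: ``future.result(timeout=...)`` raises
    # ``concurrent.futures.TimeoutError`` to the caller, but leaving the
    # ``with`` block still joins the worker thread, so a timed-out call blocks
    # until the underlying ``chat`` actually returns.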

    def pdf_ocr(
        self,
        path: Path,
        *,
        backend: OcrBackend,
        model: str = "",
        per_page_timeout_s: float | None = None,
        quiet: bool = True,
        on_progress: Callable[..., None] | None = None,
    ) -> list[PageText]:
        """SDK backend cannot rasterise PDFs locally; ingest callers fall back."""
        del path, backend, model, per_page_timeout_s, quiet, on_progress
        raise NotImplementedError(
            "Hosted models do not support scanned-PDF OCR. "
            "Set LILBEE_VISION_MODEL to a local GGUF vision model to enable it."
        )

    def list_models(self) -> list[str]:
        """List models from the backend (empty list when listing is unsupported)."""
        try:
            return self._backend.list_models(base_url=self._base_url, api_key=self._api_key)
        except NotImplementedError:
            return []
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Listing models failed: {exc}", provider=self._backend.provider_name
            ) from exc

    def list_chat_models(self, provider: str) -> list[str]:
        """List frontier chat models known to the backend for *provider*.

        Initializes the backend first so ``cfg.json_mode`` suppression is
        applied before the SDK import inside the backend runs.
        """
        self._ensure_initialized()
        try:
            return self._backend.list_chat_models(provider)
        except NotImplementedError:
            return []
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Listing chat models failed: {exc}", provider=self._backend.provider_name
            ) from exc

    def pull_model(self, model: str, *, on_progress: Callable[..., Any] | None = None) -> None:
        """Pull a model via the backend."""
        try:
            self._backend.pull_model(model, base_url=self._base_url, on_progress=on_progress)
        except NotImplementedError as exc:
            raise ProviderError(
                f"Cannot pull model {model!r}: backend does not support pulling",
                provider=self._backend.provider_name,
            ) from exc
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Cannot pull model {model!r}: {exc}", provider=self._backend.provider_name
            ) from exc

    def show_model(self, model: str) -> dict[str, Any] | None:
        """Return model metadata, or None when unsupported or not found."""
        try:
            return self._backend.show_model(model, base_url=self._base_url)
        except NotImplementedError:
            return None
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Showing model {model!r} failed: {exc}", provider=self._backend.provider_name
            ) from exc

    def get_capabilities(self, model: str) -> list[str]:
        """Return capability tags from ``show_model`` output, or ``[]``."""
        info = self.show_model(model)
        if info is None:
            return []
        caps = info.get("capabilities", [])
        return caps if isinstance(caps, list) else []
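
    # Catalog-call sketch (model name and capability tag are illustrative;
    # actual tags depend on what the backend's ``show_model`` reports):
    #
    #     if "vision" in provider.get_capabilities("llava:13b"):
    #         text = provider.vision_ocr(png, model="llava:13b", timeout=60.0)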

    def rerank(self, query: str, candidates: list[str]) -> list[float]:
        """Rerank candidates via the SDK backend using ``cfg.reranker_model``."""
        if not candidates:
            return []
        self._ensure_initialized()
        ref = parse_model_ref(cfg.reranker_model)
        request = RerankRequest(
            ref=ref,
            query=query,
            candidates=candidates,
            api_base=self._base_url if ref.needs_api_base else None,
            api_key=self._api_key or None,
        )
        try:
            result = self._backend.rerank(request)
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Rerank failed: {exc}", provider=self._backend.provider_name
            ) from exc
        return result.scores

    def supports_rerank(self) -> bool:
        """SDK-backed rerank is available when the underlying SDK is importable."""
        return self._backend.available()

    def available(self) -> bool:
        """Return True when the configured SDK backend can service catalog calls."""
        return self._backend.available()

    def shutdown(self) -> None:
        """SDK-backed providers hold no lilbee-side resources."""

    def invalidate_load_cache(self, model_path: Path | None = None) -> None:
        """No-op: cloud backends have no local model cache to evict."""