Coverage for src / lilbee / providers / sdk_llm_provider.py: 100%
159 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
"""SDK-agnostic LLM provider implementing the public ``LLMProvider`` Protocol.

``SdkLLMProvider`` owns the semantic layer: auth key injection, option
translation, model-ref parsing, error wrapping, and lazy one-shot
backend initialization (``configure_logging`` + ``inject_provider_keys``
on first use). It talks to the underlying SDK exclusively through an
``LlmSdkBackend``, so swapping SDKs is a one-file adapter change.

No SDK is imported directly here. The adapter owns SDK-specific
concerns such as wire-format prefixes (``ollama/``) and the OpenAI
content-parts schema for image inputs.
"""
14from __future__ import annotations
16import logging
17import os
18from collections.abc import Callable
19from pathlib import Path
20from typing import Any, Literal, overload
22from lilbee.core.config import cfg
23from lilbee.providers.base import ClosableIterator, LLMProvider, ProviderError
24from lilbee.providers.local_servers import LOCAL_SERVER_KEYS
25from lilbee.providers.local_servers.config_urls import base_url_for, configured_local_servers
26from lilbee.providers.model_ref import ProviderModelRef, parse_model_ref, translate_options
27from lilbee.providers.sdk_backend import (
28 PROVIDER_KEYS,
29 CompletionRequest,
30 EmbeddingRequest,
31 LlmSdkBackend,
32 RerankRequest,
33)
34from lilbee.providers.worker.transport import OcrBackend
35from lilbee.vision import PageText
# Module-level logger; used below for debug diagnostics (e.g. when the
# backend's ``configure_logging`` hook is unavailable).
log = logging.getLogger(__name__)
def _api_base_for(ref: ProviderModelRef) -> str | None:
    """Return the base URL to send *ref*'s requests to.

    Only refs whose provider is one of the local servers get an explicit
    endpoint; hosted APIs return ``None`` because the SDK already knows
    where to reach them.
    """
    if ref.provider not in LOCAL_SERVER_KEYS:
        return None
    return base_url_for(ref.provider)
def inject_provider_keys() -> None:
    """Export per-provider API keys from ``cfg`` into ``os.environ``.

    OpenAI-compatible SDKs look up provider-specific environment variables
    (``OPENAI_API_KEY``, ``ANTHROPIC_API_KEY``, ...) at call time, so this
    bridges lilbee's config onto that convention. An env var that already
    holds a non-empty value is left untouched, so users can still override
    config from their shell. An unset or empty env var is filled whenever
    the matching config field is non-empty.
    """
    for _provider, field_name, env_name, _extra in PROVIDER_KEYS:
        configured = getattr(cfg, field_name, "")
        if not configured:
            continue
        if os.environ.get(env_name):
            # Respect a value the user exported explicitly.
            continue
        os.environ[env_name] = configured
class SdkLLMProvider(LLMProvider):
    """Provider that delegates SDK calls to an ``LlmSdkBackend``.

    Chat, embedding and rerank requests are built from ``cfg`` (or an
    explicit ``model``) and handed to the backend. A ``ProviderError``
    raised by the backend passes through unchanged. Any other backend
    exception is re-raised as ``ProviderError`` tagged with the backend's
    ``provider_name``, so callers only need to handle one error type.
    """

    def __init__(
        self,
        backend: LlmSdkBackend,
        *,
        api_key: str = "",
    ) -> None:
        """Wrap *backend*.

        Args:
            backend: Adapter that performs the actual SDK calls.
            api_key: Optional key forwarded to the backend. On chat, embed
                and rerank requests an empty string is sent as ``None``.
        """
        self._backend = backend
        self._api_key = api_key
        # Set by ``_ensure_initialized`` once the one-shot setup has run.
        self._initialized = False

    def _provider_error(self, message: str, exc: BaseException) -> ProviderError:
        """Build a ``ProviderError`` reading ``"<message>: <exc>"`` for this backend."""
        return ProviderError(f"{message}: {exc}", provider=self._backend.provider_name)

    @staticmethod
    def _close_quietly(resource: object) -> None:
        """Call ``resource.close()`` when it exists; log and swallow any failure."""
        close = getattr(resource, "close", None)
        if not callable(close):
            return
        try:
            close()
        except Exception:
            log.debug("closing backend stream failed", exc_info=True)

    def _ensure_initialized(self) -> None:
        """Apply one-shot backend setup before the first call.

        Runs ``configure_logging(suppress_debug=cfg.json_mode)`` and
        ``inject_provider_keys()`` once per provider instance. Both steps
        happen together because the backend's first SDK import must see
        (a) the debug flag applied, and (b) per-provider API keys in
        ``os.environ``.
        """
        if self._initialized:
            return
        try:
            self._backend.configure_logging(suppress_debug=cfg.json_mode)
        except (ImportError, AttributeError):
            # Logging setup is best-effort; a backend without the hook (or
            # without its SDK installed) can still serve requests.
            log.debug("backend.configure_logging failed", exc_info=True)
        inject_provider_keys()
        self._initialized = True

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* with ``cfg.embedding_model`` via the backend.

        An empty *texts* returns ``[]`` without contacting the backend
        (mirroring ``rerank``), since many embedding APIs reject empty input.

        Raises:
            ProviderError: if the backend call fails.
        """
        if not texts:
            return []
        self._ensure_initialized()
        ref = parse_model_ref(cfg.embedding_model)
        request = EmbeddingRequest(
            ref=ref,
            inputs=texts,
            api_base=_api_base_for(ref),
            api_key=self._api_key or None,
        )
        try:
            result = self._backend.embed(request)
        except ProviderError:
            raise
        except Exception as exc:
            raise self._provider_error("Embedding failed", exc) from exc
        return result.vectors

    @overload
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: Literal[False] = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str: ...

    @overload
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: Literal[True],
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> ClosableIterator[str]: ...

    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: bool = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str | ClosableIterator[str]:
        """Chat completion via the configured backend.

        Args:
            messages: Chat messages; copied before being placed on the request.
            stream: When true, return an iterator of content tokens instead
                of the full response text.
            options: Generation options, translated for *model*'s provider.
            model: Model ref; defaults to ``cfg.chat_model``.

        Raises:
            ProviderError: if the non-streaming backend call fails. Streaming
                errors surface while iterating the returned iterator.
        """
        self._ensure_initialized()
        ref = parse_model_ref(model or cfg.chat_model)
        translated = translate_options(options, ref) if options else {}
        request = CompletionRequest(
            ref=ref,
            messages=list(messages),
            options=translated,
            api_base=_api_base_for(ref),
            api_key=self._api_key or None,
        )
        if stream:
            return self._chat_stream(request)
        try:
            result = self._backend.complete(request)
        except ProviderError:
            raise
        except Exception as exc:
            raise self._provider_error("Chat failed", exc) from exc
        return result.content

    def _chat_stream(self, request: CompletionRequest) -> ClosableIterator[str]:
        """Yield non-empty content tokens from a streaming completion.

        Exceptions raised by the backend, either when starting the stream or
        during iteration, are re-raised as ``ProviderError`` so callers
        always see a consistent error type. The backend stream is closed
        when this generator finishes, fails, or is closed early by the
        consumer, so an abandoned stream does not keep its connection open.
        """
        stream = None
        try:
            stream = self._backend.complete_stream(request)
            for chunk in stream:
                if chunk.content:
                    yield chunk.content
        except ProviderError:
            raise
        except Exception as exc:
            raise self._provider_error("Chat failed", exc) from exc
        finally:
            # Also runs on GeneratorExit (consumer called ``close()``).
            self._close_quietly(stream)

    def vision_ocr(
        self,
        png_bytes: bytes,
        model: str,
        prompt: str = "",
        *,
        timeout: float | None = None,
    ) -> str:
        """OCR a PNG image through a multipart chat completion with *model*.

        Args:
            png_bytes: PNG image bytes passed to ``build_vision_messages``.
            model: Model ref passed to ``chat``.
            prompt: Instruction text; ``OCR_PROMPT`` is used when empty.
            timeout: Seconds to wait for the completion. ``None`` or a value
                ``<= 0`` waits indefinitely.

        Returns:
            The model's text response.

        Raises:
            concurrent.futures.TimeoutError: if *timeout* elapses first.
            ProviderError: if the chat call fails or returns non-text.
        """
        from lilbee.vision import OCR_PROMPT, build_vision_messages

        messages = build_vision_messages(prompt or OCR_PROMPT, png_bytes)
        if timeout and timeout > 0:
            result = self._chat_with_timeout(messages, model, timeout)
        else:
            result = self.chat(messages, stream=False, model=model)
        if not isinstance(result, str):
            raise ProviderError(
                f"Vision OCR returned non-text response ({type(result).__name__}).",
                provider=self._backend.provider_name,
            )
        return result

    def _chat_with_timeout(
        self, messages: list[dict[str, Any]], model: str, timeout: float
    ) -> object:
        """Run a non-streaming ``chat`` on a worker thread, waiting at most *timeout* s.

        The executor is shut down without waiting. A ``with
        ThreadPoolExecutor(...)`` block would join the worker on exit, so a
        hung call would block the caller and the timeout would never take
        effect. Here a timed-out call is left to finish in the background
        while ``concurrent.futures.TimeoutError`` propagates immediately.
        """
        from concurrent.futures import ThreadPoolExecutor

        pool = ThreadPoolExecutor(max_workers=1)
        try:
            future = pool.submit(self.chat, messages, stream=False, model=model)
            return future.result(timeout=timeout)
        finally:
            pool.shutdown(wait=False)

    def pdf_ocr(
        self,
        path: Path,
        *,
        backend: OcrBackend,
        model: str = "",
        per_page_timeout_s: float | None = None,
        quiet: bool = True,
        on_progress: Callable[..., None] | None = None,
    ) -> list[PageText]:
        """SDK backend cannot rasterise PDFs locally; ingest callers fall back.

        Raises:
            NotImplementedError: always.
        """
        del path, backend, model, per_page_timeout_s, quiet, on_progress
        raise NotImplementedError(
            "Hosted models do not support scanned-PDF OCR. "
            "Set LILBEE_VISION_MODEL to a local GGUF vision model to enable it."
        )

    def list_models(self) -> list[str]:
        """List models across every configured local server.

        Servers whose backend raises ``NotImplementedError`` are skipped.

        Raises:
            ProviderError: if the backend fails for any other reason.
        """
        names: list[str] = []
        for _spec, base_url in configured_local_servers():
            try:
                names.extend(self._backend.list_models(base_url=base_url, api_key=self._api_key))
            except NotImplementedError:
                continue
            except ProviderError:
                raise
            except Exception as exc:
                raise self._provider_error("Listing models failed", exc) from exc
        return names

    def list_chat_models(self, provider: str) -> list[str]:
        """List frontier chat models known to the backend for *provider*.

        Initializes the backend first so ``cfg.json_mode`` suppression is
        applied before the SDK import inside the backend runs. Returns ``[]``
        when the backend does not implement this query.

        Raises:
            ProviderError: if the backend call fails.
        """
        self._ensure_initialized()
        try:
            return self._backend.list_chat_models(provider)
        except NotImplementedError:
            return []
        except ProviderError:
            raise
        except Exception as exc:
            raise self._provider_error("Listing chat models failed", exc) from exc

    def pull_model(self, model: str, *, on_progress: Callable[..., Any] | None = None) -> None:
        """Pull *model* via the backend, reporting progress to *on_progress*.

        Raises:
            ProviderError: if the backend cannot pull (including when it does
                not support pulling at all) or the pull fails.
        """
        try:
            base_url = _api_base_for(parse_model_ref(model)) or ""
            self._backend.pull_model(model, base_url=base_url, on_progress=on_progress)
        except NotImplementedError as exc:
            raise ProviderError(
                f"Cannot pull model {model!r}: backend does not support pulling",
                provider=self._backend.provider_name,
            ) from exc
        except ProviderError:
            raise
        except Exception as exc:
            raise self._provider_error(f"Cannot pull model {model!r}", exc) from exc

    def show_model(self, model: str) -> dict[str, Any] | None:
        """Return the backend's metadata for *model*.

        Returns ``None`` when the backend does not implement this query, or
        whatever the backend itself returns otherwise.

        Raises:
            ProviderError: if the backend call fails.
        """
        try:
            base_url = _api_base_for(parse_model_ref(model)) or ""
            return self._backend.show_model(model, base_url=base_url)
        except NotImplementedError:
            return None
        except ProviderError:
            raise
        except Exception as exc:
            raise self._provider_error(f"Showing model {model!r} failed", exc) from exc

    def get_capabilities(self, model: str) -> list[str]:
        """Return the ``capabilities`` list from ``show_model`` output, or ``[]``.

        ``[]`` is also returned when metadata is unavailable or the
        ``capabilities`` entry is not a list.
        """
        info = self.show_model(model)
        if info is None:
            return []
        caps = info.get("capabilities", [])
        return caps if isinstance(caps, list) else []

    def rerank(self, query: str, candidates: list[str]) -> list[float]:
        """Score *candidates* against *query* with ``cfg.reranker_model``.

        An empty *candidates* returns ``[]`` without contacting the backend.

        Raises:
            ProviderError: if the backend call fails.
        """
        if not candidates:
            return []
        self._ensure_initialized()
        ref = parse_model_ref(cfg.reranker_model)
        request = RerankRequest(
            ref=ref,
            query=query,
            candidates=candidates,
            api_base=_api_base_for(ref),
            api_key=self._api_key or None,
        )
        try:
            result = self._backend.rerank(request)
        except ProviderError:
            raise
        except Exception as exc:
            raise self._provider_error("Rerank failed", exc) from exc
        return result.scores

    def supports_rerank(self) -> bool:
        """SDK-backed rerank is available whenever the backend reports itself available."""
        return self._backend.available()

    def available(self) -> bool:
        """Return True when the configured SDK backend can service catalog calls."""
        return self._backend.available()

    def shutdown(self) -> None:
        """SDK-backed providers hold no lilbee-side resources."""

    def invalidate_load_cache(self, model_path: Path | None = None) -> None:
        """No-op: cloud backends have no local model cache to evict."""