Coverage for src/lilbee/providers/litellm_sdk.py: 100%
270 statements
coverage.py v7.13.4, created at 2026-05-15 20:55 +0000
1"""litellm implementation of the ``LlmSdkBackend`` Protocol.
3This is the ONLY file in lilbee that imports ``litellm``. When migrating
4to a different SDK (e.g. ``liter-llm``), add a sibling module alongside
5this one and flip the single import in ``providers/factory.py``.
7All knowledge of the litellm wire format (``ollama/`` prefix, OpenAI
8content-parts schema for images) lives here. The semantic layer in
9``sdk_llm_provider`` never touches SDK-specific conventions.
10"""
from __future__ import annotations

import base64
import functools
import json
import logging
from collections.abc import Callable, Iterator
from typing import Any

import httpx

from lilbee.core.config import DEFAULT_HTTP_TIMEOUT
from lilbee.providers.base import ProviderError
from lilbee.providers.model_ref import OLLAMA_PREFIX, ProviderModelRef
from lilbee.providers.sdk_backend import (
    CompletionRequest,
    CompletionResult,
    EmbeddingRequest,
    EmbeddingResult,
    RerankRequest,
    RerankResult,
    StreamChunk,
    detect_backend_name,
)
log = logging.getLogger(__name__)

_PROVIDER_NAME = "litellm"
_OLLAMA_URL_PATTERNS = ("localhost:11434", "127.0.0.1:11434", "ollama")

# Substrings dropped from the "LiteLLM" logger before they reach the user's
# terminal. Two classes of noise: (1) the model-cost-map fetch failure that
# LiteLLM logs at WARNING on every offline chat call, and (2) AWS-flavored
# advisories from sagemaker / bedrock / boto3 / botocore. lilbee's litellm
# extra deliberately excludes boto3, so the AWS warnings aren't actionable.
# Compared case-insensitively to catch the mixed-case variants LiteLLM emits.
_LITELLM_SUPPRESS_SUBSTRINGS = (
    "failed to fetch remote model cost map",
    "boto3",
    "botocore",
    "sagemaker",
    "bedrock",
)
class _LitellmSubstringFilter(logging.Filter):
    """Drop ``LiteLLM`` log records whose message contains a suppressed substring."""

    def __init__(self, needles: tuple[str, ...]) -> None:
        super().__init__()
        self._needles = tuple(n.lower() for n in needles)

    def filter(self, record: logging.LogRecord) -> bool:
        msg = record.getMessage().lower()
        return not any(n in msg for n in self._needles)


def install_litellm_log_filter() -> None:
    """Attach the ``LiteLLM`` substring filter to the package logger.

    Called automatically when this module is imported (see the module-top
    invocation below) so the filter is in place before any litellm call
    can emit a warning. Exposed as a function so tests can re-apply after
    clearing the logger.
    """
    logging.getLogger("LiteLLM").addFilter(_LitellmSubstringFilter(_LITELLM_SUPPRESS_SUBSTRINGS))
# Install the filter at module import. lilbee never touches litellm before
# importing this module, so installing here always beats litellm's first
# warning to the punch.
install_litellm_log_filter()
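# Illustrative effect (not executed here): once the filter is installed, a record
# such as ``logging.getLogger("LiteLLM").warning("Failed to fetch remote model cost map")``
# is dropped by the case-insensitive substring match above, while unrelated warnings
# from the same logger still reach the user's terminal.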
class _LitellmResponseView:
    """Typed read-only view over a litellm completion-response object.

    The litellm response shape is not in the SDK's type stubs. This
    adapter is the one place that knows how to pull ``model``, ``choices``,
    ``message_content`` and the streaming chunk fields out; SDK drift
    breaks here rather than across every caller.
    """

    def __init__(self, response: Any) -> None:
        self._response = response

    @property
    def model(self) -> str | None:
        """The model name the SDK echoed back, if any."""
        value = getattr(self._response, "model", None)
        return str(value) if value is not None else None

    def _first_choice(self) -> Any:
        """First entry of the response's ``choices`` list, or ``None``."""
        choices = getattr(self._response, "choices", None) or []
        return choices[0] if choices else None

    @property
    def message_content(self) -> str:
        """Content text of the first choice's message (non-stream path)."""
        choice = self._first_choice()
        if choice is None:
            return ""
        message = getattr(choice, "message", None)
        if message is None:
            return ""
        return getattr(message, "content", "") or ""

    @property
    def delta_content(self) -> str:
        """Content delta of the first choice (stream-path chunk)."""
        choice = self._first_choice()
        if choice is None:
            return ""
        delta = getattr(choice, "delta", None)
        if delta is None:
            return ""
        return getattr(delta, "content", "") or ""

    @property
    def finish_reason(self) -> str | None:
        """``finish_reason`` of the first choice, if the SDK populated it."""
        choice = self._first_choice()
        return getattr(choice, "finish_reason", None) if choice is not None else None


def _is_ollama(base_url: str) -> bool:
    """Return True if *base_url* looks like an Ollama instance."""
    url_lower = base_url.lower()
    return any(p in url_lower for p in _OLLAMA_URL_PATTERNS)
@functools.cache
def litellm_available() -> bool:
    """Return True if the ``litellm`` package is installed.

    Uses ``importlib.util.find_spec`` rather than ``import litellm`` so the
    check stays fast on the UI thread. Importing ``litellm`` on Windows
    with Defender real-time scanning takes seconds (the package loads a
    long list of provider plugins on first import); the Settings screen
    builds synchronously and calls this in ``_FEATURE_GATED_GROUPS``, so
    a real import here blocks the entire TUI on the first Settings open.
    ``find_spec`` just walks ``sys.path`` to locate the package; the
    heavy import runs later, in worker threads or remote-call paths
    where the cost is expected.
    """
    import importlib.util

    return importlib.util.find_spec("litellm") is not None
_LITELLM_MISSING_MSG = (
    "Remote and API models need the lilbee[litellm] extra. "
    "Reinstall with: uv tool install --prerelease=allow 'lilbee[litellm]'"
)


def _require_litellm() -> Any:
    """Import ``litellm`` or raise a user-facing ProviderError with install steps."""
    try:
        import litellm
    except ImportError as exc:
        raise ProviderError(_LITELLM_MISSING_MSG, provider=_PROVIDER_NAME) from exc
    return litellm
def _cache_ollama_defaults(model: str, params_text: str) -> None:
    """Parse Ollama parameters and store in the model defaults cache."""
    from lilbee.providers.model_defaults import parse_kv_parameters, set_defaults

    defaults = parse_kv_parameters(params_text)
    set_defaults(model, defaults)
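# Illustrative input (shape of the "parameters" text Ollama's /api/show returns;
# keys vary per model, column alignment collapsed here):
#   temperature 0.7
#   num_ctx 4096
#   stop "<|im_start|>"
# parse_kv_parameters is expected to turn those key/value lines into a dict of
# generation defaults, which set_defaults then caches under the model name.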
def _route_model(ref: ProviderModelRef, api_base: str | None) -> str:
    """Format *ref* for litellm using the OpenAI ``provider/model`` convention."""
    if ref.is_api:
        return ref.for_openai_prefix()
    if api_base and _is_ollama(api_base):
        return f"{OLLAMA_PREFIX}{ref.name}"
    return ref.name
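# Routing sketch (model names illustrative, assuming OLLAMA_PREFIX == "ollama/"):
#   API ref (ref.is_api)                 -> ref.for_openai_prefix(), e.g. "mistral/codestral-latest"
#   base URL that looks like Ollama      -> prefixed bare name, e.g. "ollama/llama3.2"
#   any other OpenAI-compatible base URL -> bare ref.name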
def _format_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert messages with inline image bytes into OpenAI content parts.

    litellm routes to OpenAI-compatible endpoints that expect the
    ``{"type": "image_url", "image_url": {...}}`` content-parts schema
    for multimodal input. Messages without ``images`` pass through
    untouched.
    """
    formatted: list[dict[str, Any]] = []
    for msg in messages:
        if "images" in msg:
            content_parts: list[dict[str, Any]] = [{"type": "text", "text": msg.get("content", "")}]
            for img in msg["images"]:
                if isinstance(img, bytes):
                    b64 = base64.b64encode(img).decode()
                    content_parts.append(
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/png;base64,{b64}"},
                        }
                    )
            formatted.append({"role": msg["role"], "content": content_parts})
        else:
            formatted.append(msg)
    return formatted
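# Illustrative transformation (image bytes and base64 shortened):
#   in:  {"role": "user", "content": "What is in this picture?", "images": [b"\x89PNG..."]}
#   out: {"role": "user", "content": [
#            {"type": "text", "text": "What is in this picture?"},
#            {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0..."}},
#        ]}
# A message without an "images" key is appended to the result unchanged.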
class LitellmSdkBackend:
    """``LlmSdkBackend`` adapter backed by the ``litellm`` SDK."""

    @property
    def provider_name(self) -> str:
        """Stable identifier used when wrapping errors in ``ProviderError``."""
        return _PROVIDER_NAME

    def active_backend_name(self, base_url: str) -> str:
        """Return the display name of the backend ``base_url`` points at."""
        return detect_backend_name(base_url)

    def available(self) -> bool:
        """Return True if the underlying SDK is installed."""
        return litellm_available()

    def configure_logging(self, *, suppress_debug: bool) -> None:
        """Apply litellm's debug-info suppression toggle when requested."""
        if not suppress_debug:
            return
        try:
            import litellm

            litellm.suppress_debug_info = True
        except ImportError:
            pass  # debug-suppression is best-effort when the litellm extra is absent

    def complete(self, request: CompletionRequest) -> CompletionResult:
        """Run a single-shot completion through ``litellm.completion``."""
        litellm = _require_litellm()
        kwargs = self._completion_kwargs(request, stream=False)
        try:
            response = litellm.completion(**kwargs)
        except Exception as exc:
            raise ProviderError(f"Chat failed: {exc}", provider=_PROVIDER_NAME) from exc
        view = _LitellmResponseView(response)
        return CompletionResult(
            content=view.message_content,
            finish_reason=view.finish_reason,
            model=view.model,
        )

    def complete_stream(self, request: CompletionRequest) -> Iterator[StreamChunk]:
        """Stream a completion through ``litellm.completion(stream=True)``."""
        litellm = _require_litellm()
        kwargs = self._completion_kwargs(request, stream=True)
        try:
            response = litellm.completion(**kwargs)
        except Exception as exc:
            raise ProviderError(f"Chat failed: {exc}", provider=_PROVIDER_NAME) from exc
        return self._stream_chunks(response)
    @staticmethod
    def _stream_chunks(response: Any) -> Iterator[StreamChunk]:
        """Yield ``StreamChunk`` values from a litellm streaming response.

        Exceptions raised mid-iteration are wrapped in ``ProviderError``
        so the semantic layer sees a consistent error type regardless of
        where the SDK failed.
        """
        try:
            for chunk in response:
                view = _LitellmResponseView(chunk)
                content = view.delta_content
                finish_reason = view.finish_reason
                if content or finish_reason:
                    yield StreamChunk(content=content, finish_reason=finish_reason)
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(f"Chat failed: {exc}", provider=_PROVIDER_NAME) from exc

    @staticmethod
    def _completion_kwargs(request: CompletionRequest, *, stream: bool) -> dict[str, Any]:
        """Translate a ``CompletionRequest`` into litellm kwargs."""
        kwargs: dict[str, Any] = {
            "model": _route_model(request.ref, request.api_base),
            "messages": _format_messages(request.messages),
            "stream": stream,
        }
        if request.api_base:
            kwargs["api_base"] = request.api_base
        if request.api_key:
            kwargs["api_key"] = request.api_key
        if request.options:
            kwargs.update(request.options)
        return kwargs
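    # Sketch of the kwargs this produces for a local Ollama chat (values illustrative):
    #   {
    #       "model": "ollama/llama3.2",
    #       "messages": [...],          # after _format_messages
    #       "stream": False,
    #       "api_base": "http://localhost:11434",
    #       "temperature": 0.2,         # merged in from request.options
    #   }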
    def embed(self, request: EmbeddingRequest) -> EmbeddingResult:
        """Embed inputs through ``litellm.embedding``."""
        litellm = _require_litellm()
        kwargs: dict[str, Any] = {
            "model": _route_model(request.ref, request.api_base),
            "input": request.inputs,
        }
        if request.api_base:
            kwargs["api_base"] = request.api_base
        if request.api_key:
            kwargs["api_key"] = request.api_key
        try:
            response = litellm.embedding(**kwargs)
        except Exception as exc:
            raise ProviderError(f"Embedding failed: {exc}", provider=_PROVIDER_NAME) from exc
        data = response["data"] if isinstance(response, dict) else response.data
        vectors = [item["embedding"] for item in data]
        if isinstance(response, dict):
            model = response.get("model")
        else:
            model = getattr(response, "model", None)
        return EmbeddingResult(vectors=vectors, model=model)
    def rerank(self, request: RerankRequest) -> RerankResult:
        """Rerank documents via ``litellm.rerank`` (Cohere, Voyage, Jina, Together, HF TEI).

        The SDK returns results sorted by relevance; we restore input
        order via each result's ``index`` so scores line up with the
        caller's ``candidates`` list.
        """
        if not request.candidates:
            return RerankResult(scores=[])
        litellm = _require_litellm()
        kwargs: dict[str, Any] = {
            "model": _route_model(request.ref, request.api_base),
            "query": request.query,
            "documents": request.candidates,
        }
        if request.api_base:
            kwargs["api_base"] = request.api_base
        if request.api_key:
            kwargs["api_key"] = request.api_key
        try:
            response = litellm.rerank(**kwargs)
        except Exception as exc:
            raise ProviderError(f"Rerank failed: {exc}", provider=_PROVIDER_NAME) from exc
        results = response["results"] if isinstance(response, dict) else response.results
        scores = [0.0] * len(request.candidates)
        for item in results:
            idx = item["index"] if isinstance(item, dict) else item.index
            score = item["relevance_score"] if isinstance(item, dict) else item.relevance_score
            scores[idx] = float(score)
        if isinstance(response, dict):
            model = response.get("model")
        else:
            model = getattr(response, "model", None)
        return RerankResult(scores=scores, model=model)
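    # Reordering sketch: for candidates ["a", "b", "c"], an SDK response of
    #   [{"index": 2, "relevance_score": 0.91}, {"index": 0, "relevance_score": 0.40}]
    # yields scores == [0.40, 0.0, 0.91]; positions track the caller's list, and
    # any candidate the backend omitted keeps the 0.0 placeholder.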
    def list_models(self, *, base_url: str, api_key: str) -> list[str]:
        """List models from Ollama or an OpenAI-compatible server."""
        clean_base = base_url.rstrip("/")
        if _is_ollama(clean_base):
            return self._list_ollama_models(clean_base)
        return self._list_openai_models(clean_base, api_key)

    def list_chat_models(self, provider: str) -> list[str]:
        """Return chat-mode model ids from litellm's static catalog.

        Returns whatever litellm exposes for *provider*, alphabetically.
        Empty list when litellm is not installed or the provider has no
        chat-mode entries.
        """
        try:
            import litellm
        except ImportError:
            return []
        return self._all_chat_models_for(provider, litellm)

    @staticmethod
    def _all_chat_models_for(provider: str, litellm: Any) -> list[str]:
        """Filter litellm's catalog down to chat-mode entries for ``provider``.

        litellm's catalog stores some providers' models bare (``gpt-4o``)
        and others prefixed (``mistral/codestral-latest``,
        ``openrouter/anthropic/claude-3.5-sonnet``). Strip any leading
        ``{provider}/`` so callers see uniformly bare names; the canonical
        ``provider/name`` form is reapplied at the routing layer via
        :meth:`ProviderModelRef.for_openai_prefix`.
        """
        models = litellm.models_by_provider.get(provider, set())
        prefix = f"{provider}/"
        bare: set[str] = set()
        for model_name in models:
            info = litellm.model_cost.get(model_name, {})
            if info.get("mode") != "chat":
                continue
            bare.add(model_name.removeprefix(prefix))
        return sorted(bare)
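    # Normalisation sketch (entries illustrative): for provider "mistral", a chat-mode
    # catalog entry "mistral/codestral-latest" comes back as "codestral-latest", while
    # an already-bare entry is returned unchanged; entries whose mode is anything other
    # than "chat" (e.g. "embedding") are skipped.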
    @staticmethod
    def _list_ollama_models(base_url: str) -> list[str]:
        """List models via the Ollama ``/api/tags`` endpoint."""
        try:
            resp = httpx.get(f"{base_url}/api/tags", timeout=DEFAULT_HTTP_TIMEOUT)
            resp.raise_for_status()
            data = resp.json()
            return [m["name"] for m in data.get("models", [])]
        except httpx.HTTPError as exc:
            raise ProviderError(f"Cannot list models: {exc}", provider=_PROVIDER_NAME) from exc
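    # Expected /api/tags payload shape (abridged, model names illustrative):
    #   {"models": [{"name": "llama3.2:latest", ...}, {"name": "qwen2.5:7b", ...}]}
    # Only each entry's "name" field is used here.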
    @staticmethod
    def _list_openai_models(base_url: str, api_key: str) -> list[str]:
        """List models via an OpenAI-compatible ``/v1/models`` endpoint."""
        headers: dict[str, str] = {}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        try:
            resp = httpx.get(f"{base_url}/v1/models", headers=headers, timeout=DEFAULT_HTTP_TIMEOUT)
            resp.raise_for_status()
            data = resp.json()
            return [m["id"] for m in data.get("data", [])]
        except httpx.HTTPError:
            log.debug("Failed to list models via /v1/models", exc_info=True)
            return []
    def pull_model(
        self,
        model: str,
        *,
        base_url: str,
        on_progress: Callable[..., Any] | None = None,
    ) -> None:
        """Pull a model via the Ollama ``/api/pull`` endpoint."""
        clean_base = base_url.rstrip("/")
        try:
            with (
                # Streaming Ollama /api/pull; unbounded read is intentional
                # since model downloads can exceed any wall-clock timeout.
                httpx.Client(timeout=None) as client,  # noqa: S113
                client.stream(
                    "POST",
                    f"{clean_base}/api/pull",
                    json={"name": model, "stream": True},
                ) as resp,
            ):
                resp.raise_for_status()
                for line in resp.iter_lines():
                    if not line:
                        continue
                    event = json.loads(line)
                    if on_progress:
                        on_progress(event)
                    if event.get("status") == "success":
                        break
        except httpx.HTTPError as exc:
            raise ProviderError(
                f"Cannot pull model {model!r}: {exc}", provider=_PROVIDER_NAME
            ) from exc
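    # Typical /api/pull event stream, one JSON object per line (fields vary with the
    # Ollama version; shown only as a sketch):
    #   {"status": "pulling manifest"}
    #   {"status": "pulling 6a0746a1ec1a", "total": 4661211808, "completed": 1073741824}
    #   {"status": "verifying sha256 digest"}
    #   {"status": "success"}
    # Each decoded event is handed to on_progress; the loop stops at "success".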
    def show_model(self, model: str, *, base_url: str) -> dict[str, Any] | None:
        """Get model info via the Ollama ``/api/show`` endpoint.

        Parses and caches per-model generation defaults from the
        ``parameters`` field. Also extracts the ``capabilities`` list
        (newer Ollama versions) so callers can check for vision support.
        """
        clean_base = base_url.rstrip("/")
        # Ollama's API uses bare model names; the routing-layer prefix has
        # to come off before the request goes out.
        ollama_name = model[len(OLLAMA_PREFIX) :] if model.startswith(OLLAMA_PREFIX) else model
        try:
            resp = httpx.post(
                f"{clean_base}/api/show",
                json={"name": ollama_name},
                timeout=DEFAULT_HTTP_TIMEOUT,
            )
            resp.raise_for_status()
            data = resp.json()
        except httpx.HTTPError:
            return None

        result: dict[str, Any] = {}

        params = data.get("parameters", "")
        if isinstance(params, str) and params:
            _cache_ollama_defaults(model, params)
            result["parameters"] = params
        elif params:
            _cache_ollama_defaults(model, str(params))
            result["parameters"] = str(params)

        capabilities = data.get("capabilities")
        if isinstance(capabilities, list):
            result["capabilities"] = capabilities

        return result or None
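    # Illustrative return value for a vision-capable model (field contents depend on
    # the Ollama version and the model's Modelfile):
    #   {
    #       "parameters": 'temperature 0.7\nstop "<|im_start|>"',
    #       "capabilities": ["completion", "vision"],
    #   }
    # Older Ollama servers omit "capabilities"; when neither field is present the
    # method returns None.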