Coverage for src / lilbee / providers / litellm_sdk.py: 100%
292 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
1"""litellm implementation of the ``LlmSdkBackend`` Protocol.
3This is the ONLY file in lilbee that imports ``litellm``. When migrating
4to a different SDK (e.g. ``liter-llm``), add a sibling module alongside
5this one and flip the single import in ``providers/factory.py``.
7All knowledge of the litellm wire format (``ollama/`` prefix, OpenAI
8content-parts schema for images) lives here. The semantic layer in
9``sdk_llm_provider`` never touches SDK-specific conventions.
10"""
12from __future__ import annotations
14import base64
15import functools
16import logging
17from collections.abc import Callable, Iterator
18from typing import Any
20import httpx
22from lilbee.core.config import DEFAULT_HTTP_TIMEOUT
23from lilbee.providers.base import ProviderError, ProviderErrorKind
24from lilbee.providers.local_servers import (
25 OLLAMA,
26 detect_local_server,
27 local_server_for_key,
28 openai_models_url,
29)
30from lilbee.providers.model_ref import ProviderModelRef
31from lilbee.providers.sdk_backend import (
32 CompletionRequest,
33 CompletionResult,
34 EmbeddingRequest,
35 EmbeddingResult,
36 RerankRequest,
37 RerankResult,
38 StreamChunk,
39 detect_backend_name,
40)
log = logging.getLogger(__name__)

# Identifier stamped onto every ProviderError raised from this module.
_PROVIDER_NAME = "litellm"

# Fragments that, when found anywhere in a record emitted by the "LiteLLM"
# logger, cause that record to be dropped before it reaches the terminal.
# They cover two kinds of noise:
#   * the model-cost-map download failure LiteLLM warns about on every
#     offline chat call;
#   * AWS advisories (sagemaker / bedrock / boto3 / botocore). The litellm
#     extra for lilbee deliberately ships without boto3, so those warnings
#     give the user nothing to act on.
# The filter lower-cases both sides, so LiteLLM's mixed-case spellings match too.
_LITELLM_SUPPRESS_SUBSTRINGS = (
    "failed to fetch remote model cost map",
    "boto3",
    "botocore",
    "sagemaker",
    "bedrock",
)
class _LitellmSubstringFilter(logging.Filter):
    """Reject ``LiteLLM`` records whose rendered message mentions a blocked fragment.

    Matching is case-insensitive: the fragments are lower-cased once at
    construction and each rendered message is lower-cased before testing.
    """

    def __init__(self, needles: tuple[str, ...]) -> None:
        super().__init__()
        lowered = [needle.lower() for needle in needles]
        self._needles = tuple(lowered)

    def filter(self, record: logging.LogRecord) -> bool:
        """Return ``False`` (drop the record) when any blocked fragment occurs in it."""
        text = record.getMessage().lower()
        return all(needle not in text for needle in self._needles)
def install_litellm_log_filter() -> None:
    """Attach the ``LiteLLM`` substring filter to the package logger, once.

    Called automatically when this module is imported (see the invocation
    right after this function) so the filter is in place before any litellm
    call can emit a warning. Exposed as a function so tests can re-apply it
    after clearing the logger.

    Idempotent: if a ``_LitellmSubstringFilter`` is already attached, the call
    is a no-op. ``Logger.addFilter`` only ignores the *same* filter object, so
    without this check every call would stack another instance.
    """
    logger = logging.getLogger("LiteLLM")
    if any(isinstance(existing, _LitellmSubstringFilter) for existing in logger.filters):
        return
    logger.addFilter(_LitellmSubstringFilter(_LITELLM_SUPPRESS_SUBSTRINGS))


# Install the filter at module import. lilbee never touches litellm before
# importing this module, so installing here always beats litellm's first
# warning to the punch.
install_litellm_log_filter()
class _LitellmResponseView:
    """Read-only adapter over an untyped litellm completion response (or stream chunk).

    litellm's response objects are not covered by its type stubs. All
    knowledge of how to read ``model``, ``choices[0].message.content``,
    ``choices[0].delta.content`` and ``choices[0].finish_reason`` lives in
    this class. Every lookup is attribute-based and degrades to ``""`` or
    ``None`` when a link in the chain is missing, so SDK shape drift surfaces
    here instead of in every caller.
    """

    def __init__(self, response: Any) -> None:
        self._response = response

    @property
    def model(self) -> str | None:
        """Model name echoed by the SDK, stringified; ``None`` when absent."""
        echoed = getattr(self._response, "model", None)
        if echoed is None:
            return None
        return str(echoed)

    def _first_choice(self) -> Any:
        """Return ``choices[0]``, or ``None`` when there are no choices."""
        choices = getattr(self._response, "choices", None)
        if not choices:
            return None
        return choices[0]

    def _choice_content(self, container: str) -> str:
        """Return ``choices[0].<container>.content``, or ``""`` if any part is missing or falsy."""
        choice = self._first_choice()
        holder = None if choice is None else getattr(choice, container, None)
        if holder is None:
            return ""
        return getattr(holder, "content", "") or ""

    @property
    def message_content(self) -> str:
        """Text of the first choice's ``message`` (non-streaming responses)."""
        return self._choice_content("message")

    @property
    def delta_content(self) -> str:
        """Text of the first choice's ``delta`` (streaming chunks)."""
        return self._choice_content("delta")

    @property
    def finish_reason(self) -> str | None:
        """``finish_reason`` of the first choice, or ``None`` if not populated."""
        choice = self._first_choice()
        if choice is None:
            return None
        return getattr(choice, "finish_reason", None)
@functools.cache
def litellm_available() -> bool:
    """Report whether the ``litellm`` package is installed, without importing it.

    Only ``importlib.util.find_spec`` is consulted: it walks ``sys.path`` to
    locate the package but never executes it. A real ``import litellm`` can
    take seconds on Windows with Defender real-time scanning, because the
    package loads a long list of provider plugins on first import. The
    Settings screen calls this synchronously while building
    ``_FEATURE_GATED_GROUPS``, so importing here would freeze the TUI the
    first time Settings opens. The expensive import is left to worker threads
    and remote-call paths, where the cost is expected. The answer is cached
    for the life of the process.
    """
    from importlib.util import find_spec

    spec = find_spec("litellm")
    return spec is not None
# User-facing text raised (via ProviderError) whenever an SDK-backed call needs
# ``litellm`` but the optional ``lilbee[litellm]`` extra is not installed.
_LITELLM_MISSING_MSG: str = (
    "Remote and API models need the lilbee[litellm] extra. "
    "Reinstall with: uv tool install --prerelease=allow 'lilbee[litellm]'"
)
def _require_litellm() -> Any:
    """Return the imported ``litellm`` module.

    Raises:
        ProviderError: when ``litellm`` cannot be imported. The message tells
            the user how to reinstall with the ``lilbee[litellm]`` extra.
    """
    try:
        import litellm as sdk
    except ImportError as missing:
        raise ProviderError(_LITELLM_MISSING_MSG, provider=_PROVIDER_NAME) from missing
    else:
        return sdk
def _cache_ollama_defaults(model: str, params_text: str) -> None:
    """Store *model*'s generation defaults parsed from Ollama's ``parameters`` text.

    The import is deferred to call time, matching the rest of this module's
    lazy-import style.
    """
    from lilbee.providers.model_defaults import parse_kv_parameters, set_defaults

    parsed = parse_kv_parameters(params_text)
    set_defaults(model, parsed)
def _route_model(ref: ProviderModelRef, api_base: str | None) -> str:
    """Return the litellm model string for *ref* (OpenAI ``provider/model`` convention).

    Resolution order:
    1. API refs, and refs whose provider is a known local server, already
       have a canonical prefix and use ``ref.for_openai_prefix()``.
    2. Otherwise, if *api_base* is set and points at a recognised local
       server, that server qualifies the bare name. This happens when a
       ``local`` ref is forced through the SDK with ``llm_provider=remote``.
    3. Failing both, the bare ``ref.name`` is returned unchanged.
    """
    has_canonical_prefix = ref.is_api or local_server_for_key(ref.provider) is not None
    if has_canonical_prefix:
        return ref.for_openai_prefix()
    server = detect_local_server(api_base) if api_base else None
    if server is None:
        return ref.name
    return server.qualify(ref.name)
# Leading magic bytes -> MIME type for the common raster formats vision
# endpoints take. Anything unrecognised keeps the historical ``image/png`` label.
_IMAGE_SIGNATURES: tuple[tuple[bytes, str], ...] = (
    (b"\x89PNG\r\n\x1a\n", "image/png"),
    (b"\xff\xd8\xff", "image/jpeg"),
    (b"GIF87a", "image/gif"),
    (b"GIF89a", "image/gif"),
)


def _sniff_image_mime(data: bytes) -> str:
    """Guess the MIME type of *data* from its leading bytes; default ``image/png``.

    The type goes into the ``data:`` URL. A provider that checks the declared
    media type against the payload could reject a JPEG labelled as PNG, so the
    type is derived from the bytes instead of being hard-coded.
    """
    for signature, mime in _IMAGE_SIGNATURES:
        if data.startswith(signature):
            return mime
    # WebP is a RIFF container: "RIFF" <4-byte size> "WEBP".
    if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
        return "image/webp"
    return "image/png"


def _format_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert messages with inline image bytes into OpenAI content parts.

    litellm routes to OpenAI-compatible endpoints that expect the
    ``{"type": "image_url", "image_url": {...}}`` content-parts schema
    for multimodal input.

    Messages without an ``images`` key are passed through as the same object.
    For a message with ``images``, the returned copy:

    - drops ``images``;
    - keeps every other key (``role`` and any extras);
    - has ``content`` replaced by a list: a text part (``""`` when the
      original content is missing or ``None``), then one base64 ``data:`` URL
      part per bytes-like image, with the MIME type sniffed from its bytes.

    Non-bytes image entries cannot be encoded here and are skipped with a
    debug log.
    """
    formatted: list[dict[str, Any]] = []
    for msg in messages:
        if "images" not in msg:
            formatted.append(msg)
            continue
        parts: list[dict[str, Any]] = [{"type": "text", "text": msg.get("content") or ""}]
        for img in msg["images"]:
            if not isinstance(img, (bytes, bytearray, memoryview)):
                log.debug("Skipping non-bytes image payload of type %s", type(img).__name__)
                continue
            raw = bytes(img)
            b64 = base64.b64encode(raw).decode()
            parts.append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:{_sniff_image_mime(raw)};base64,{b64}"},
                }
            )
        converted = {key: value for key, value in msg.items() if key != "images"}
        converted["content"] = parts
        formatted.append(converted)
    return formatted
# Friendly text for each classified failure kind. ``_provider_error`` fills in
# ``{model}``. Each message makes clear the cause lies with the user's provider
# account or network rather than lilbee. Kinds missing from this table
# (notably ``UNKNOWN``) fall back to the raw ``"{label}: {exc}"`` message.
_KIND_MESSAGES: dict[ProviderErrorKind, str] = {
    # Quota / throttling on the user's key.
    ProviderErrorKind.RATE_LIMIT: (
        "{model} is rate-limited or out of quota. That's a limit on your provider "
        "API key, not a lilbee problem. Check your plan and billing with the "
        "provider, or pick a different model."
    ),
    # Bad key, or a key without access to the model.
    ProviderErrorKind.AUTH: (
        "{model} rejected your API key. Check that the key is set correctly and has "
        "access to this model. That's between your key and the provider, not a lilbee problem."
    ),
    ProviderErrorKind.NOT_FOUND: (
        "The provider doesn't offer {model} on your account. "
        "Pick a different model or check the name."
    ),
    ProviderErrorKind.CONTEXT_OVERFLOW: (
        "This conversation is too long for {model}'s context window. "
        "Start a new chat or pick a model with a larger context."
    ),
    ProviderErrorKind.BAD_REQUEST: (
        "The provider rejected the request for {model}. Check the model name and your settings."
    ),
    # Network failures and timeouts share one message.
    ProviderErrorKind.CONNECTION: (
        "Couldn't reach the provider for {model}, or it timed out. Check your "
        "connection and base URL, then try again or pick a different model."
    ),
    ProviderErrorKind.SERVER: (
        "The provider for {model} is unavailable right now. That's on the provider's "
        "side, not a lilbee problem. Try again shortly or pick a different model."
    ),
}
# Operation names that prefix the raw "{label}: {exc}" message produced by
# ``_provider_error`` when an error kind has no friendly template.
_CHAT_FAILED = "Chat failed"
_EMBED_FAILED = "Embedding failed"
_RERANK_FAILED = "Rerank failed"
def _cause_chain(exc: BaseException) -> list[BaseException]:
    """Return *exc* plus every exception it wraps, ordered root cause first.

    The next link is ``original_exception`` when that attribute holds an
    exception, since litellm's mid-stream fallback stores the real cause
    there. Otherwise the next link is ``__cause__``. The walk stops when a
    link is missing or when an exception already visited reappears, which
    guards against cycles. Putting the root first stops a 503 wrapper from
    masking the 429 it carries.
    """
    visited: list[BaseException] = []
    seen_ids: set[int] = set()
    node: BaseException | None = exc
    while node is not None:
        if id(node) in seen_ids:
            break
        seen_ids.add(id(node))
        visited.append(node)
        wrapped = getattr(node, "original_exception", None)
        if isinstance(wrapped, BaseException):
            node = wrapped
        elif isinstance(node.__cause__, BaseException):
            node = node.__cause__
        else:
            node = None
    return visited[::-1]
def _classify_litellm_error(exc: BaseException) -> ProviderErrorKind:
    """Return the ``ProviderErrorKind`` for a litellm failure, judged by type only.

    The message text is never inspected. litellm funnels every backend's
    failures into one exception hierarchy, so one table covers all
    providers. Exceptions are scanned root cause first (see
    ``_cause_chain``). For each one, its MRO is walked from most to least
    specific, so ``ContextWindowExceededError`` wins over its
    ``BadRequestError`` base. ``UNKNOWN`` is returned when nothing matches.
    """
    try:
        import litellm
    except ImportError:  # pragma: no cover - unreachable after a real litellm call
        return ProviderErrorKind.UNKNOWN
    kind_by_type: dict[type, ProviderErrorKind] = {
        litellm.AuthenticationError: ProviderErrorKind.AUTH,
        litellm.PermissionDeniedError: ProviderErrorKind.AUTH,
        litellm.NotFoundError: ProviderErrorKind.NOT_FOUND,
        litellm.RateLimitError: ProviderErrorKind.RATE_LIMIT,
        litellm.ContextWindowExceededError: ProviderErrorKind.CONTEXT_OVERFLOW,
        litellm.BadRequestError: ProviderErrorKind.BAD_REQUEST,
        litellm.Timeout: ProviderErrorKind.CONNECTION,
        litellm.APIConnectionError: ProviderErrorKind.CONNECTION,
        litellm.ServiceUnavailableError: ProviderErrorKind.SERVER,
        litellm.InternalServerError: ProviderErrorKind.SERVER,
    }
    for failure in _cause_chain(exc):
        matched = next(
            (kind_by_type[cls] for cls in type(failure).__mro__ if cls in kind_by_type),
            None,
        )
        if matched is not None:
            return matched
    return ProviderErrorKind.UNKNOWN
def _provider_error(fallback: str, exc: Exception, model: str) -> ProviderError:
    """Build (not raise) a ``ProviderError`` for a litellm failure, classified by type.

    Args:
        fallback: operation label such as ``_CHAT_FAILED``. It is used only
            when the kind has no friendly template.
        exc: the exception raised by the SDK.
        model: display name substituted into the friendly template.

    Kinds with a template get its formatted text. Unrecognised kinds keep
    the raw ``"{fallback}: {exc}"`` form so debugging detail survives.
    """
    kind = _classify_litellm_error(exc)
    template = _KIND_MESSAGES.get(kind)
    if template is None:
        message = f"{fallback}: {exc}"
    else:
        message = template.format(model=model)
    return ProviderError(message, provider=_PROVIDER_NAME, kind=kind)
class LitellmSdkBackend:
    """``LlmSdkBackend`` adapter backed by the ``litellm`` SDK.

    Completion, embedding and rerank requests are translated into ``litellm``
    calls, and SDK failures are re-raised as classified ``ProviderError``
    values. Model listing and metadata lookups bypass the SDK and query the
    server's HTTP API directly with ``httpx``.
    """

    @property
    def provider_name(self) -> str:
        """Stable identifier used when wrapping errors in ``ProviderError``."""
        return _PROVIDER_NAME

    def active_backend_name(self, base_url: str) -> str:
        """Return the display name of the backend ``base_url`` points at."""
        return detect_backend_name(base_url)

    def available(self) -> bool:
        """Return True if the underlying SDK is installed."""
        return litellm_available()

    def configure_logging(self, *, suppress_debug: bool) -> None:
        """Set ``litellm.suppress_debug_info`` when *suppress_debug* is true.

        Best-effort: nothing happens if the litellm extra is not installed.
        """
        if not suppress_debug:
            return
        try:
            import litellm
        except ImportError:
            return  # debug-suppression is best-effort when the litellm extra is absent
        litellm.suppress_debug_info = True

    @staticmethod
    def _connection_kwargs(api_base: str | None, api_key: str | None) -> dict[str, Any]:
        """Return ``api_base`` / ``api_key`` kwargs, omitting empty or ``None`` values."""
        kwargs: dict[str, Any] = {}
        if api_base:
            kwargs["api_base"] = api_base
        if api_key:
            kwargs["api_key"] = api_key
        return kwargs

    @staticmethod
    def _response_model(response: Any) -> Any:
        """Return the ``model`` field of a dict- or attribute-shaped SDK response, else ``None``."""
        if isinstance(response, dict):
            return response.get("model")
        return getattr(response, "model", None)

    def complete(self, request: CompletionRequest) -> CompletionResult:
        """Run a single-shot completion through ``litellm.completion``.

        Raises:
            ProviderError: when litellm is missing or the SDK call fails.
        """
        litellm = _require_litellm()
        kwargs = self._completion_kwargs(request, stream=False)
        try:
            response = litellm.completion(**kwargs)
        except Exception as exc:
            raise _provider_error(_CHAT_FAILED, exc, request.ref.for_display()) from exc
        view = _LitellmResponseView(response)
        return CompletionResult(
            content=view.message_content,
            finish_reason=view.finish_reason,
            model=view.model,
        )

    def complete_stream(self, request: CompletionRequest) -> Iterator[StreamChunk]:
        """Stream a completion through ``litellm.completion(stream=True)``.

        The SDK call is made eagerly, so connection-time failures raise
        ``ProviderError`` here. Failures during iteration are wrapped by
        ``_stream_chunks``.
        """
        litellm = _require_litellm()
        kwargs = self._completion_kwargs(request, stream=True)
        model = request.ref.for_display()
        try:
            response = litellm.completion(**kwargs)
        except Exception as exc:
            raise _provider_error(_CHAT_FAILED, exc, model) from exc
        return self._stream_chunks(response, model)

    @staticmethod
    def _stream_chunks(response: Any, model: str) -> Iterator[StreamChunk]:
        """Yield ``StreamChunk`` values from a litellm streaming response.

        Chunks with neither content nor a finish reason are skipped.
        Exceptions raised mid-iteration are classified into ``ProviderError``
        so the semantic layer sees a consistent error type regardless of
        where the SDK failed.
        """
        try:
            for chunk in response:
                view = _LitellmResponseView(chunk)
                content = view.delta_content
                finish_reason = view.finish_reason
                if content or finish_reason:
                    yield StreamChunk(content=content, finish_reason=finish_reason)
        except ProviderError:
            raise
        except Exception as exc:
            raise _provider_error(_CHAT_FAILED, exc, model) from exc

    @staticmethod
    def _completion_kwargs(request: CompletionRequest, *, stream: bool) -> dict[str, Any]:
        """Translate a ``CompletionRequest`` into litellm kwargs.

        ``request.options`` is applied last, so its keys override the defaults.
        """
        kwargs: dict[str, Any] = {
            "model": _route_model(request.ref, request.api_base),
            "messages": _format_messages(request.messages),
            "stream": stream,
            **LitellmSdkBackend._connection_kwargs(request.api_base, request.api_key),
        }
        if request.options:
            kwargs.update(request.options)
        return kwargs

    def embed(self, request: EmbeddingRequest) -> EmbeddingResult:
        """Embed ``request.inputs`` through ``litellm.embedding``.

        Raises:
            ProviderError: when litellm is missing or the SDK call fails.
        """
        litellm = _require_litellm()
        kwargs: dict[str, Any] = {
            "model": _route_model(request.ref, request.api_base),
            "input": request.inputs,
            **self._connection_kwargs(request.api_base, request.api_key),
        }
        try:
            response = litellm.embedding(**kwargs)
        except Exception as exc:
            raise _provider_error(_EMBED_FAILED, exc, request.ref.for_display()) from exc
        data = response["data"] if isinstance(response, dict) else response.data
        vectors = [item["embedding"] for item in data]
        return EmbeddingResult(vectors=vectors, model=self._response_model(response))

    def rerank(self, request: RerankRequest) -> RerankResult:
        """Rerank documents via ``litellm.rerank`` (Cohere, Voyage, Jina, Together, HF TEI).

        The SDK returns results sorted by relevance. Each result's ``index``
        is used to put scores back in the caller's ``candidates`` order.
        Candidates the SDK did not score keep ``0.0``. Results whose index
        falls outside ``candidates`` are ignored rather than overwriting
        another slot.

        Raises:
            ProviderError: when litellm is missing or the SDK call fails.
        """
        if not request.candidates:
            return RerankResult(scores=[])
        litellm = _require_litellm()
        kwargs: dict[str, Any] = {
            "model": _route_model(request.ref, request.api_base),
            "query": request.query,
            "documents": request.candidates,
            **self._connection_kwargs(request.api_base, request.api_key),
        }
        try:
            response = litellm.rerank(**kwargs)
        except Exception as exc:
            raise _provider_error(_RERANK_FAILED, exc, request.ref.for_display()) from exc
        results = response["results"] if isinstance(response, dict) else response.results
        scores = [0.0] * len(request.candidates)
        for item in results:
            if isinstance(item, dict):
                idx, score = item["index"], item["relevance_score"]
            else:
                idx, score = item.index, item.relevance_score
            # A negative index would silently overwrite a slot counted from
            # the end, and an oversized one would raise a bare IndexError.
            if not 0 <= idx < len(scores):
                log.debug("Ignoring rerank result with out-of-range index %r", idx)
                continue
            scores[idx] = float(score)
        return RerankResult(scores=scores, model=self._response_model(response))

    def list_models(self, *, base_url: str, api_key: str) -> list[str]:
        """List models from Ollama (``/api/tags``) or an OpenAI-compatible ``/v1/models``."""
        clean_base = base_url.rstrip("/")
        spec = detect_local_server(clean_base)
        if spec is OLLAMA:
            return self._list_ollama_models(clean_base)
        return self._list_openai_models(clean_base, api_key)

    def list_chat_models(self, provider: str) -> list[str]:
        """Return chat-mode model ids from litellm's static catalog.

        Returns whatever litellm exposes for *provider*, alphabetically.
        Empty list when litellm is not installed or the provider has no
        chat-mode entries.
        """
        try:
            import litellm
        except ImportError:
            return []
        return self._all_chat_models_for(provider, litellm)

    @staticmethod
    def _all_chat_models_for(provider: str, litellm: Any) -> list[str]:
        """Filter litellm's catalog down to chat-mode entries for ``provider``.

        litellm's catalog stores some providers' models bare (``gpt-4o``)
        and others prefixed (``mistral/codestral-latest``,
        ``openrouter/anthropic/claude-3.5-sonnet``). Any leading
        ``{provider}/`` is stripped so callers see uniformly bare names. The
        canonical ``provider/name`` form is reapplied at the routing layer
        via :meth:`ProviderModelRef.for_openai_prefix`.
        """
        models = litellm.models_by_provider.get(provider, set())
        prefix = f"{provider}/"
        return sorted(
            {
                model_name.removeprefix(prefix)
                for model_name in models
                if litellm.model_cost.get(model_name, {}).get("mode") == "chat"
            }
        )

    @staticmethod
    def _list_ollama_models(base_url: str) -> list[str]:
        """List models via the Ollama ``/api/tags`` endpoint.

        Raises:
            ProviderError: on an HTTP failure or a response body that is not JSON.
        """
        try:
            resp = httpx.get(f"{base_url}/api/tags", timeout=DEFAULT_HTTP_TIMEOUT)
            resp.raise_for_status()
            data = resp.json()
        except (httpx.HTTPError, ValueError) as exc:
            # ``resp.json()`` raises ValueError (JSONDecodeError) on a non-JSON
            # body, which ``httpx.HTTPError`` does not cover.
            raise ProviderError(f"Cannot list models: {exc}", provider=_PROVIDER_NAME) from exc
        models = data.get("models", []) if isinstance(data, dict) else []
        return [m["name"] for m in models]

    @staticmethod
    def _list_openai_models(base_url: str, api_key: str) -> list[str]:
        """List models via an OpenAI-compatible ``/v1/models`` endpoint.

        Best-effort: on an HTTP failure or a non-JSON body, logs at DEBUG and
        returns ``[]``.
        """
        headers: dict[str, str] = {}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        try:
            resp = httpx.get(
                openai_models_url(base_url), headers=headers, timeout=DEFAULT_HTTP_TIMEOUT
            )
            resp.raise_for_status()
            data = resp.json()
        except (httpx.HTTPError, ValueError):
            log.debug("Failed to list models via /v1/models", exc_info=True)
            return []
        entries = data.get("data", []) if isinstance(data, dict) else []
        return [m["id"] for m in entries]

    def pull_model(
        self,
        model: str,
        *,
        base_url: str,
        on_progress: Callable[..., Any] | None = None,
    ) -> None:
        """Refuse to pull: local servers (Ollama, LM Studio) are read-only.

        Their models are managed in their own app and surface here once
        present, so lilbee never downloads them over the network.

        Raises:
            ProviderError: always.
        """
        spec = detect_local_server(base_url.rstrip("/"))
        server = spec.display_name if spec is not None else "This server"
        raise ProviderError(
            f"{server} doesn't download models over the network. "
            "Add the model in its own app, then pick it here.",
            provider=_PROVIDER_NAME,
        )

    def show_model(self, model: str, *, base_url: str) -> dict[str, Any] | None:
        """Get model info via the Ollama ``/api/show`` endpoint.

        Parses and caches per-model generation defaults from the
        ``parameters`` field, and copies the ``capabilities`` list when the
        server returns one (newer Ollama versions) so callers can check for
        vision support. Returns ``None`` for servers without a metadata
        endpoint (LM Studio), on HTTP failure, for a non-JSON or non-object
        body, or when neither field is present.
        """
        clean_base = base_url.rstrip("/")
        spec = detect_local_server(clean_base)
        if spec is None or not spec.supports_show:
            return None
        # Ollama's API uses bare model names; the routing-layer prefix has
        # to come off before the request goes out.
        ollama_name = model.removeprefix(OLLAMA.wire_prefix)
        try:
            resp = httpx.post(
                f"{clean_base}/api/show",
                json={"name": ollama_name},
                timeout=DEFAULT_HTTP_TIMEOUT,
            )
            resp.raise_for_status()
            data = resp.json()
        except (httpx.HTTPError, ValueError):
            return None
        if not isinstance(data, dict):
            return None

        result: dict[str, Any] = {}

        params = data.get("parameters", "")
        if params:
            # ``_cache_ollama_defaults`` parses text, so non-string payloads are stringified.
            params_text = params if isinstance(params, str) else str(params)
            _cache_ollama_defaults(model, params_text)
            result["parameters"] = params_text

        capabilities = data.get("capabilities")
        if isinstance(capabilities, list):
            result["capabilities"] = capabilities

        return result or None