Coverage for src/lilbee/providers/litellm_sdk.py: 100%

270 statements  

coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

1"""litellm implementation of the ``LlmSdkBackend`` Protocol. 

2 

3This is the ONLY file in lilbee that imports ``litellm``. When migrating 

4to a different SDK (e.g. ``liter-llm``), add a sibling module alongside 

5this one and flip the single import in ``providers/factory.py``. 

6 

7All knowledge of the litellm wire format (``ollama/`` prefix, OpenAI 

8content-parts schema for images) lives here. The semantic layer in 

9``sdk_llm_provider`` never touches SDK-specific conventions. 

10""" 

11 

12from __future__ import annotations 

13 

14import base64 

15import functools 

16import json 

17import logging 

18from collections.abc import Callable, Iterator 

19from typing import Any 

20 

21import httpx 

22 

23from lilbee.core.config import DEFAULT_HTTP_TIMEOUT 

24from lilbee.providers.base import ProviderError 

25from lilbee.providers.model_ref import OLLAMA_PREFIX, ProviderModelRef 

26from lilbee.providers.sdk_backend import ( 

27 CompletionRequest, 

28 CompletionResult, 

29 EmbeddingRequest, 

30 EmbeddingResult, 

31 RerankRequest, 

32 RerankResult, 

33 StreamChunk, 

34 detect_backend_name, 

35) 
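
# Illustrative sketch of the backend flip described in the module docstring.
# The real providers/factory.py may differ; the function and sibling-module
# names below are assumptions, not code that exists in lilbee:
#
#     # providers/factory.py (hypothetical shape)
#     from lilbee.providers.litellm_sdk import LitellmSdkBackend
#
#     def make_sdk_backend() -> LitellmSdkBackend:
#         return LitellmSdkBackend()
#
# Migrating to another SDK would mean adding e.g. providers/other_sdk.py with
# the same Protocol surface and changing only that one import.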

log = logging.getLogger(__name__)

_PROVIDER_NAME = "litellm"
_OLLAMA_URL_PATTERNS = ("localhost:11434", "127.0.0.1:11434", "ollama")

# Substrings dropped from the "LiteLLM" logger before they reach the user's
# terminal. Two classes of noise: (1) the model-cost-map fetch failure that
# LiteLLM logs at WARNING on every offline chat call, and (2) AWS-flavored
# advisories from sagemaker / bedrock / boto3 / botocore. lilbee's litellm
# extra deliberately excludes boto3, so the AWS warnings aren't actionable.
# Compared case-insensitively to catch the mixed-case variants LiteLLM emits.
_LITELLM_SUPPRESS_SUBSTRINGS = (
    "failed to fetch remote model cost map",
    "boto3",
    "botocore",
    "sagemaker",
    "bedrock",
)


class _LitellmSubstringFilter(logging.Filter):
    """Drop ``LiteLLM`` log records whose message contains a suppressed substring."""

    def __init__(self, needles: tuple[str, ...]) -> None:
        super().__init__()
        self._needles = tuple(n.lower() for n in needles)

    def filter(self, record: logging.LogRecord) -> bool:
        msg = record.getMessage().lower()
        return not any(n in msg for n in self._needles)


def install_litellm_log_filter() -> None:
    """Attach the ``LiteLLM`` substring filter to the package logger.

    Called automatically when this module is imported (see the module-top
    invocation below) so the filter is in place before any litellm call
    can emit a warning. Exposed as a function so tests can re-apply after
    clearing the logger.
    """
    logging.getLogger("LiteLLM").addFilter(_LitellmSubstringFilter(_LITELLM_SUPPRESS_SUBSTRINGS))


# Install the filter at module import. lilbee never touches litellm before
# importing this module, so installing here always beats litellm's first
# warning to the punch.
install_litellm_log_filter()
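
# Illustrative use of the re-apply hook mentioned in the docstring above
# (hypothetical test code, not part of lilbee):
#
#     noisy = logging.getLogger("LiteLLM")
#     noisy.filters.clear()                  # e.g. a fixture wiped the logger
#     install_litellm_log_filter()           # put the substring filter back
#     noisy.warning("Failed to fetch remote model cost map")  # now dropped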


class _LitellmResponseView:
    """Typed read-only view over a litellm completion-response object.

    The litellm response shape is not in the SDK's type stubs. This
    adapter is the one place that knows how to pull ``model``, ``choices``,
    ``message_content`` and the streaming chunk fields out; SDK drift
    breaks here rather than across every caller.
    """

    def __init__(self, response: Any) -> None:
        self._response = response

    @property
    def model(self) -> str | None:
        """The model name the SDK echoed back, if any."""
        value = getattr(self._response, "model", None)
        return str(value) if value is not None else None

    def _first_choice(self) -> Any:
        """First entry of the response's ``choices`` list, or ``None``."""
        choices = getattr(self._response, "choices", None) or []
        return choices[0] if choices else None

    @property
    def message_content(self) -> str:
        """Content text of the first choice's message (non-stream path)."""
        choice = self._first_choice()
        if choice is None:
            return ""
        message = getattr(choice, "message", None)
        if message is None:
            return ""
        return getattr(message, "content", "") or ""

    @property
    def delta_content(self) -> str:
        """Content delta of the first choice (stream-path chunk)."""
        choice = self._first_choice()
        if choice is None:
            return ""
        delta = getattr(choice, "delta", None)
        if delta is None:
            return ""
        return getattr(delta, "content", "") or ""

    @property
    def finish_reason(self) -> str | None:
        """``finish_reason`` of the first choice, if the SDK populated it."""
        choice = self._first_choice()
        return getattr(choice, "finish_reason", None) if choice is not None else None


def _is_ollama(base_url: str) -> bool:
    """Return True if *base_url* looks like an Ollama instance."""
    url_lower = base_url.lower()
    return any(p in url_lower for p in _OLLAMA_URL_PATTERNS)


@functools.cache
def litellm_available() -> bool:
    """Return True if the ``litellm`` package is installed.

    Uses ``importlib.util.find_spec`` rather than ``import litellm`` so the
    check stays fast on the UI thread. Importing ``litellm`` on Windows
    with Defender real-time scanning takes seconds (the package loads a
    long list of provider plugins on first import); the Settings screen
    builds synchronously and calls this in ``_FEATURE_GATED_GROUPS``, so
    a real import here blocks the entire TUI on the first Settings open.
    ``find_spec`` just walks ``sys.path`` to locate the package; the
    heavy import runs later, in worker threads or remote-call paths
    where the cost is expected.
    """
    import importlib.util

    return importlib.util.find_spec("litellm") is not None


_LITELLM_MISSING_MSG = (
    "Remote and API models need the lilbee[litellm] extra. "
    "Reinstall with: uv tool install --prerelease=allow 'lilbee[litellm]'"
)


def _require_litellm() -> Any:
    """Import ``litellm`` or raise a user-facing ProviderError with install steps."""
    try:
        import litellm
    except ImportError as exc:
        raise ProviderError(_LITELLM_MISSING_MSG, provider=_PROVIDER_NAME) from exc
    return litellm


def _cache_ollama_defaults(model: str, params_text: str) -> None:
    """Parse Ollama parameters and store in the model defaults cache."""
    from lilbee.providers.model_defaults import parse_kv_parameters, set_defaults

    defaults = parse_kv_parameters(params_text)
    set_defaults(model, defaults)


def _route_model(ref: ProviderModelRef, api_base: str | None) -> str:
    """Format *ref* for litellm using the OpenAI ``provider/model`` convention."""
    if ref.is_api:
        return ref.for_openai_prefix()
    if api_base and _is_ollama(api_base):
        return f"{OLLAMA_PREFIX}{ref.name}"
    return ref.name
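
# Routing outcomes, illustrated (the refs and URLs are example values;
# OLLAMA_PREFIX is the ``ollama/`` routing prefix described in the module
# docstring):
#
#     _route_model(<API ref>, None)
#         -> ref.for_openai_prefix(), e.g. "openrouter/gpt-4o"
#     _route_model(<local ref "qwen2.5:7b">, "http://localhost:11434")
#         -> "ollama/qwen2.5:7b"
#     _route_model(<local ref "my-model">, "https://my-vllm.example/v1")
#         -> "my-model"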


def _format_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert messages with inline image bytes into OpenAI content parts.

    litellm routes to OpenAI-compatible endpoints that expect the
    ``{"type": "image_url", "image_url": {...}}`` content-parts schema
    for multimodal input. Messages without ``images`` pass through
    untouched.
    """
    formatted: list[dict[str, Any]] = []
    for msg in messages:
        if "images" in msg:
            content_parts: list[dict[str, Any]] = [{"type": "text", "text": msg.get("content", "")}]
            for img in msg["images"]:
                if isinstance(img, bytes):
                    b64 = base64.b64encode(img).decode()
                    content_parts.append(
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/png;base64,{b64}"},
                        }
                    )
            formatted.append({"role": msg["role"], "content": content_parts})
        else:
            formatted.append(msg)
    return formatted
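
# Shape of the transformation, illustrated with a fabricated message (the
# base64 payload is truncated):
#
#     {"role": "user", "content": "Describe this", "images": [b"\x89PNG..."]}
# becomes
#     {"role": "user", "content": [
#         {"type": "text", "text": "Describe this"},
#         {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}},
#     ]}
# while messages without an ``images`` key pass through unchanged.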


class LitellmSdkBackend:
    """``LlmSdkBackend`` adapter backed by the ``litellm`` SDK."""

    @property
    def provider_name(self) -> str:
        """Stable identifier used when wrapping errors in ``ProviderError``."""
        return _PROVIDER_NAME

    def active_backend_name(self, base_url: str) -> str:
        """Return the display name of the backend ``base_url`` points at."""
        return detect_backend_name(base_url)

    def available(self) -> bool:
        """Return True if the underlying SDK is installed."""
        return litellm_available()

    def configure_logging(self, *, suppress_debug: bool) -> None:
        """Apply litellm's debug-info suppression toggle when requested."""
        if not suppress_debug:
            return
        try:
            import litellm

            litellm.suppress_debug_info = True
        except ImportError:
            pass  # debug-suppression is best-effort when the litellm extra is absent

    def complete(self, request: CompletionRequest) -> CompletionResult:
        """Run a single-shot completion through ``litellm.completion``."""
        litellm = _require_litellm()
        kwargs = self._completion_kwargs(request, stream=False)
        try:
            response = litellm.completion(**kwargs)
        except Exception as exc:
            raise ProviderError(f"Chat failed: {exc}", provider=_PROVIDER_NAME) from exc
        view = _LitellmResponseView(response)
        return CompletionResult(
            content=view.message_content,
            finish_reason=view.finish_reason,
            model=view.model,
        )

    def complete_stream(self, request: CompletionRequest) -> Iterator[StreamChunk]:
        """Stream a completion through ``litellm.completion(stream=True)``."""
        litellm = _require_litellm()
        kwargs = self._completion_kwargs(request, stream=True)
        try:
            response = litellm.completion(**kwargs)
        except Exception as exc:
            raise ProviderError(f"Chat failed: {exc}", provider=_PROVIDER_NAME) from exc
        return self._stream_chunks(response)

    @staticmethod
    def _stream_chunks(response: Any) -> Iterator[StreamChunk]:
        """Yield ``StreamChunk`` values from a litellm streaming response.

        Exceptions raised mid-iteration are wrapped in ``ProviderError``
        so the semantic layer sees a consistent error type regardless of
        where the SDK failed.
        """
        try:
            for chunk in response:
                view = _LitellmResponseView(chunk)
                content = view.delta_content
                finish_reason = view.finish_reason
                if content or finish_reason:
                    yield StreamChunk(content=content, finish_reason=finish_reason)
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(f"Chat failed: {exc}", provider=_PROVIDER_NAME) from exc

    @staticmethod
    def _completion_kwargs(request: CompletionRequest, *, stream: bool) -> dict[str, Any]:
        """Translate a ``CompletionRequest`` into litellm kwargs."""
        kwargs: dict[str, Any] = {
            "model": _route_model(request.ref, request.api_base),
            "messages": _format_messages(request.messages),
            "stream": stream,
        }
        if request.api_base:
            kwargs["api_base"] = request.api_base
        if request.api_key:
            kwargs["api_key"] = request.api_key
        if request.options:
            kwargs.update(request.options)
        return kwargs
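
    # Example of the kwargs this produces for a local Ollama chat request
    # (illustrative values only):
    #
    #     {
    #         "model": "ollama/qwen2.5:7b",
    #         "messages": [{"role": "user", "content": "hi"}],
    #         "stream": False,
    #         "api_base": "http://localhost:11434",
    #     }
    #
    # plus any per-request options (e.g. temperature) merged in last.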

    def embed(self, request: EmbeddingRequest) -> EmbeddingResult:
        """Embed inputs through ``litellm.embedding``."""
        litellm = _require_litellm()
        kwargs: dict[str, Any] = {
            "model": _route_model(request.ref, request.api_base),
            "input": request.inputs,
        }
        if request.api_base:
            kwargs["api_base"] = request.api_base
        if request.api_key:
            kwargs["api_key"] = request.api_key
        try:
            response = litellm.embedding(**kwargs)
        except Exception as exc:
            raise ProviderError(f"Embedding failed: {exc}", provider=_PROVIDER_NAME) from exc
        data = response["data"] if isinstance(response, dict) else response.data
        vectors = [item["embedding"] for item in data]
        if isinstance(response, dict):
            model = response.get("model")
        else:
            model = getattr(response, "model", None)
        return EmbeddingResult(vectors=vectors, model=model)

    def rerank(self, request: RerankRequest) -> RerankResult:
        """Rerank documents via ``litellm.rerank`` (Cohere, Voyage, Jina, Together, HF TEI).

        The SDK returns results sorted by relevance; we restore input
        order via each result's ``index`` so scores line up with the
        caller's ``candidates`` list.
        """
        if not request.candidates:
            return RerankResult(scores=[])
        litellm = _require_litellm()
        kwargs: dict[str, Any] = {
            "model": _route_model(request.ref, request.api_base),
            "query": request.query,
            "documents": request.candidates,
        }
        if request.api_base:
            kwargs["api_base"] = request.api_base
        if request.api_key:
            kwargs["api_key"] = request.api_key
        try:
            response = litellm.rerank(**kwargs)
        except Exception as exc:
            raise ProviderError(f"Rerank failed: {exc}", provider=_PROVIDER_NAME) from exc
        results = response["results"] if isinstance(response, dict) else response.results
        scores = [0.0] * len(request.candidates)
        for item in results:
            idx = item["index"] if isinstance(item, dict) else item.index
            score = item["relevance_score"] if isinstance(item, dict) else item.relevance_score
            scores[idx] = float(score)
        if isinstance(response, dict):
            model = response.get("model")
        else:
            model = getattr(response, "model", None)
        return RerankResult(scores=scores, model=model)
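
    # Order restoration, illustrated (fabricated scores): for
    # candidates = ["a", "b", "c"], an SDK response of
    #     [{"index": 2, "relevance_score": 0.9}, {"index": 0, "relevance_score": 0.4}]
    # yields scores == [0.4, 0.0, 0.9] -- position i always scores candidates[i],
    # and documents the provider omitted keep the 0.0 default.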

    def list_models(self, *, base_url: str, api_key: str) -> list[str]:
        """List models from Ollama or an OpenAI-compatible server."""
        clean_base = base_url.rstrip("/")
        if _is_ollama(clean_base):
            return self._list_ollama_models(clean_base)
        return self._list_openai_models(clean_base, api_key)

    def list_chat_models(self, provider: str) -> list[str]:
        """Return chat-mode model ids from litellm's static catalog.

        Returns whatever litellm exposes for *provider*, alphabetically.
        Empty list when litellm is not installed or the provider has no
        chat-mode entries.
        """
        try:
            import litellm
        except ImportError:
            return []
        return self._all_chat_models_for(provider, litellm)

    @staticmethod
    def _all_chat_models_for(provider: str, litellm: Any) -> list[str]:
        """Filter litellm's catalog down to chat-mode entries for ``provider``.

        litellm's catalog stores some providers' models bare (``gpt-4o``)
        and others prefixed (``mistral/codestral-latest``,
        ``openrouter/anthropic/claude-3.5-sonnet``). Strip any leading
        ``{provider}/`` so callers see uniformly bare names; the canonical
        ``provider/name`` form is reapplied at the routing layer via
        :meth:`ProviderModelRef.for_openai_prefix`.
        """
        models = litellm.models_by_provider.get(provider, set())
        prefix = f"{provider}/"
        bare: set[str] = set()
        for model_name in models:
            info = litellm.model_cost.get(model_name, {})
            if info.get("mode") != "chat":
                continue
            bare.add(model_name.removeprefix(prefix))
        return sorted(bare)
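
    # Prefix handling, illustrated: for provider="openrouter", a catalog id like
    # "openrouter/anthropic/claude-3.5-sonnet" is returned as
    # "anthropic/claude-3.5-sonnet", while bare ids such as "gpt-4o" (under
    # provider="openai") come back unchanged; only chat-mode entries survive.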

    @staticmethod
    def _list_ollama_models(base_url: str) -> list[str]:
        """List models via the Ollama ``/api/tags`` endpoint."""
        try:
            resp = httpx.get(f"{base_url}/api/tags", timeout=DEFAULT_HTTP_TIMEOUT)
            resp.raise_for_status()
            data = resp.json()
            return [m["name"] for m in data.get("models", [])]
        except httpx.HTTPError as exc:
            raise ProviderError(f"Cannot list models: {exc}", provider=_PROVIDER_NAME) from exc

    @staticmethod
    def _list_openai_models(base_url: str, api_key: str) -> list[str]:
        """List models via an OpenAI-compatible ``/v1/models`` endpoint."""
        headers: dict[str, str] = {}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        try:
            resp = httpx.get(f"{base_url}/v1/models", headers=headers, timeout=DEFAULT_HTTP_TIMEOUT)
            resp.raise_for_status()
            data = resp.json()
            return [m["id"] for m in data.get("data", [])]
        except httpx.HTTPError:
            log.debug("Failed to list models via /v1/models", exc_info=True)
            return []

    def pull_model(
        self,
        model: str,
        *,
        base_url: str,
        on_progress: Callable[..., Any] | None = None,
    ) -> None:
        """Pull a model via the Ollama ``/api/pull`` endpoint."""
        clean_base = base_url.rstrip("/")
        try:
            with (
                # Streaming Ollama /api/pull; unbounded read is intentional
                # since model downloads can exceed any wall-clock timeout.
                httpx.Client(timeout=None) as client,  # noqa: S113
                client.stream(
                    "POST",
                    f"{clean_base}/api/pull",
                    json={"name": model, "stream": True},
                ) as resp,
            ):
                resp.raise_for_status()
                for line in resp.iter_lines():
                    if not line:
                        continue
                    event = json.loads(line)
                    if on_progress:
                        on_progress(event)
                    if event.get("status") == "success":
                        break
        except httpx.HTTPError as exc:
            raise ProviderError(
                f"Cannot pull model {model!r}: {exc}", provider=_PROVIDER_NAME
            ) from exc
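
    # Typical progress events handed to ``on_progress`` (one JSON object per
    # NDJSON line from Ollama; digests abbreviated, byte counts fabricated):
    #
    #     {"status": "pulling manifest"}
    #     {"status": "downloading sha256:...", "total": 4661211808, "completed": 1283457024}
    #     {"status": "success"}
    #
    # The loop stops at the first "success" event.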

    def show_model(self, model: str, *, base_url: str) -> dict[str, Any] | None:
        """Get model info via the Ollama ``/api/show`` endpoint.

        Parses and caches per-model generation defaults from the
        ``parameters`` field. Also extracts the ``capabilities`` list
        (newer Ollama versions) so callers can check for vision support.
        """
        clean_base = base_url.rstrip("/")
        # Ollama's API uses bare model names; the routing-layer prefix has
        # to come off before the request goes out.
        ollama_name = model[len(OLLAMA_PREFIX) :] if model.startswith(OLLAMA_PREFIX) else model
        try:
            resp = httpx.post(
                f"{clean_base}/api/show",
                json={"name": ollama_name},
                timeout=DEFAULT_HTTP_TIMEOUT,
            )
            resp.raise_for_status()
            data = resp.json()
        except httpx.HTTPError:
            return None

        result: dict[str, Any] = {}

        params = data.get("parameters", "")
        if isinstance(params, str) and params:
            _cache_ollama_defaults(model, params)
            result["parameters"] = params
        elif params:
            _cache_ollama_defaults(model, str(params))
            result["parameters"] = str(params)

        capabilities = data.get("capabilities")
        if isinstance(capabilities, list):
            result["capabilities"] = capabilities

        return result or None
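
    # Example return value (fabricated, abbreviated): for a vision-capable
    # Ollama model the shape resembles
    #
    #     {
    #         "parameters": 'num_ctx 8192\nstop "<|im_end|>"',
    #         "capabilities": ["completion", "vision"],
    #     }
    #
    # with the ``parameters`` text also parsed and cached via
    # _cache_ollama_defaults for later generation calls.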