Coverage for src/lilbee/providers/sdk_llm_provider.py: 100%

159 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-28 01:01 +0000

1"""SDK-agnostic LLM provider implementing the public ``LLMProvider`` Protocol. 

2 

3``SdkLLMProvider`` owns the semantic layer: auth key injection, option 

4translation, model-ref parsing, error wrapping, and lazy one-shot 

5backend initialization (``configure_logging`` + ``inject_provider_keys`` 

6on first use). It speaks to the underlying SDK exclusively through an 

7``LlmSdkBackend``, so swapping SDKs is a one-file adapter change. 

8 

9Zero direct SDK imports live here. The adapter owns SDK-specific 

10concerns like wire-format prefixes (``ollama/``) and OpenAI content-parts 

11schema for image inputs. 

12""" 

13 

14from __future__ import annotations 

15 

16import logging 

17import os 

18from collections.abc import Callable 

19from pathlib import Path 

20from typing import Any, Literal, overload 

21 

22from lilbee.core.config import cfg 

23from lilbee.providers.base import ClosableIterator, LLMProvider, ProviderError 

24from lilbee.providers.local_servers import LOCAL_SERVER_KEYS 

25from lilbee.providers.local_servers.config_urls import base_url_for, configured_local_servers 

26from lilbee.providers.model_ref import ProviderModelRef, parse_model_ref, translate_options 

27from lilbee.providers.sdk_backend import ( 

28 PROVIDER_KEYS, 

29 CompletionRequest, 

30 EmbeddingRequest, 

31 LlmSdkBackend, 

32 RerankRequest, 

33) 

34from lilbee.providers.worker.transport import OcrBackend 

35from lilbee.vision import PageText 

36 

# Module-level logger for this provider module.
log = logging.getLogger(name=__name__)

38 

39 

def _api_base_for(ref: ProviderModelRef) -> str | None:
    """Resolve the API base URL to send with a request for *ref*.

    Providers listed in ``LOCAL_SERVER_KEYS`` are local servers whose
    endpoint is looked up via ``base_url_for``. Any other provider is a
    hosted API that needs no explicit base, so ``None`` is returned.
    """
    is_local_server = ref.provider in LOCAL_SERVER_KEYS
    return base_url_for(ref.provider) if is_local_server else None

45 

46 

def inject_provider_keys() -> None:
    """Export configured per-provider API keys as environment variables.

    OpenAI-compatible SDKs look up provider-specific env vars
    (``OPENAI_API_KEY``, ``ANTHROPIC_API_KEY``, ...) when they are called.
    This walks ``PROVIDER_KEYS`` and copies each non-empty ``cfg`` field
    into its env var. A variable that already holds a non-empty value is
    left alone, so a key exported in the user's shell always wins. A
    variable that is set but empty is treated as unset and gets filled in.
    """
    for _first, field_name, env_name, _last in PROVIDER_KEYS:
        configured = getattr(cfg, field_name, "")
        if not configured:
            continue
        if os.environ.get(env_name):
            # An explicit, non-empty shell value takes precedence.
            continue
        os.environ[env_name] = configured

60 

61 

class SdkLLMProvider(LLMProvider):
    """Provider that delegates SDK calls to an ``LlmSdkBackend``.

    This layer handles model-ref parsing, option translation, API-key
    plumbing and error wrapping; the backend does the SDK-specific work.
    Unexpected backend exceptions are re-raised as ``ProviderError`` tagged
    with ``backend.provider_name``; a ``ProviderError`` raised by the
    backend itself passes through unchanged.
    """

    def __init__(
        self,
        backend: LlmSdkBackend,
        *,
        api_key: str = "",
    ) -> None:
        """Wrap *backend*.

        Args:
            backend: Adapter that performs the actual SDK calls.
            api_key: Explicit key sent with embed/chat/rerank requests (as
                ``None`` when empty) and passed to ``list_models``.
        """
        self._backend = backend
        self._api_key = api_key
        # Set by _ensure_initialized once the one-shot setup has run.
        self._initialized = False

    def _ensure_initialized(self) -> None:
        """Apply one-shot backend setup before the first call.

        Runs ``configure_logging(suppress_debug=cfg.json_mode)`` and
        ``inject_provider_keys()`` exactly once, regardless of whether
        the first operation is ``chat``, ``embed``, or a catalog query.
        Both steps happen together because the backend's first SDK
        import must see (a) the debug flag applied, and (b) per-provider
        API keys in ``os.environ``.
        """
        if self._initialized:
            return
        try:
            self._backend.configure_logging(suppress_debug=cfg.json_mode)
        except (ImportError, AttributeError):
            # Best-effort: a missing or broken logging hook must not block
            # the first real request.
            log.debug("backend.configure_logging failed", exc_info=True)
        inject_provider_keys()
        self._initialized = True

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* with ``cfg.embedding_model`` via the backend.

        An empty *texts* list short-circuits to ``[]`` without contacting
        the backend (mirroring ``rerank``), instead of sending an empty
        batch to the provider.

        Returns:
            The ``vectors`` field of the backend's embedding result.

        Raises:
            ProviderError: if the backend call fails for any reason.
        """
        if not texts:
            return []
        self._ensure_initialized()
        ref = parse_model_ref(cfg.embedding_model)
        request = EmbeddingRequest(
            ref=ref,
            inputs=texts,
            api_base=_api_base_for(ref),
            api_key=self._api_key or None,
        )
        try:
            result = self._backend.embed(request)
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Embedding failed: {exc}", provider=self._backend.provider_name
            ) from exc
        return result.vectors

    @overload
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: Literal[False] = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str: ...

    @overload
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: Literal[True],
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> ClosableIterator[str]: ...

    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: bool = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str | ClosableIterator[str]:
        """Run a chat completion via the configured backend.

        Args:
            messages: Chat messages; a shallow copy is handed to the backend.
            stream: When true, return a lazy iterator of content tokens.
            options: Generic options, translated for the target provider
                via ``translate_options``.
            model: Model reference; defaults to ``cfg.chat_model``.

        Returns:
            The completion text, or a ``ClosableIterator`` of non-empty
            content tokens when *stream* is true.

        Raises:
            ProviderError: on backend failure. For streams, errors surface
                while iterating, not from this call.
        """
        self._ensure_initialized()
        ref = parse_model_ref(model or cfg.chat_model)
        translated = translate_options(options, ref) if options else {}
        request = CompletionRequest(
            ref=ref,
            messages=list(messages),
            options=translated,
            api_base=_api_base_for(ref),
            api_key=self._api_key or None,
        )
        if stream:
            return self._chat_stream(request)
        try:
            result = self._backend.complete(request)
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Chat failed: {exc}", provider=self._backend.provider_name
            ) from exc
        return result.content

    def _chat_stream(self, request: CompletionRequest) -> ClosableIterator[str]:
        """Yield content tokens from a streaming completion.

        Exceptions surfaced by the backend at either call time or during
        iteration are re-raised as ``ProviderError`` so callers always
        see a consistent error type. The backend stream is closed when
        this generator finishes, fails, or is closed early by the caller.
        """
        stream = None
        try:
            stream = self._backend.complete_stream(request)
            for chunk in stream:
                if chunk.content:
                    yield chunk.content
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Chat failed: {exc}", provider=self._backend.provider_name
            ) from exc
        finally:
            # Propagate exhaustion / early close() to the backend so it can
            # release whatever the stream holds (e.g. an open response).
            self._close_stream(stream)

    @staticmethod
    def _close_stream(stream: object) -> None:
        """Call ``stream.close()`` if it exists, logging instead of raising."""
        close = getattr(stream, "close", None)
        if not callable(close):
            return
        try:
            close()
        except Exception:
            log.debug("closing backend stream failed", exc_info=True)

    def vision_ocr(
        self,
        png_bytes: bytes,
        model: str,
        prompt: str = "",
        *,
        timeout: float | None = None,
    ) -> str:
        """OCR a PNG image via a multipart chat completion.

        Args:
            png_bytes: Encoded PNG image.
            model: Vision-capable model reference.
            prompt: Instruction text; ``OCR_PROMPT`` when empty.
            timeout: Seconds to wait for the completion; ``None`` or a
                non-positive value waits indefinitely.

        Returns:
            The recognized text.

        Raises:
            concurrent.futures.TimeoutError: if *timeout* elapses first.
            ProviderError: on backend failure or a non-text response.
        """
        from lilbee.vision import OCR_PROMPT, build_vision_messages

        messages = build_vision_messages(prompt or OCR_PROMPT, png_bytes)
        if timeout and timeout > 0:
            from concurrent.futures import ThreadPoolExecutor

            pool = ThreadPoolExecutor(max_workers=1)
            try:
                future = pool.submit(self.chat, messages, stream=False, model=model)
                result = future.result(timeout=timeout)
            finally:
                # A ``with`` block would call shutdown(wait=True) and block
                # until the chat finished, defeating the timeout. Return
                # immediately instead; an abandoned worker keeps running until
                # the backend call returns and its result is discarded.
                pool.shutdown(wait=False, cancel_futures=True)
        else:
            result = self.chat(messages, stream=False, model=model)
        if not isinstance(result, str):
            raise ProviderError(
                f"Vision OCR returned non-text response ({type(result).__name__}).",
                provider=self._backend.provider_name,
            )
        return result

    def pdf_ocr(
        self,
        path: Path,
        *,
        backend: OcrBackend,
        model: str = "",
        per_page_timeout_s: float | None = None,
        quiet: bool = True,
        on_progress: Callable[..., None] | None = None,
    ) -> list[PageText]:
        """Always raise: the SDK backend cannot rasterise PDFs locally.

        Ingest callers fall back to another path. All arguments are
        accepted only to satisfy the provider interface.

        Raises:
            NotImplementedError: unconditionally.
        """
        del path, backend, model, per_page_timeout_s, quiet, on_progress
        raise NotImplementedError(
            "Hosted models do not support scanned-PDF OCR. "
            "Set LILBEE_VISION_MODEL to a local GGUF vision model to enable it."
        )

    def list_models(self) -> list[str]:
        """List model names across every configured local server.

        Servers whose backend call raises ``NotImplementedError`` are
        skipped. Any other failure aborts the whole listing.

        Raises:
            ProviderError: if listing fails on any server.
        """
        self._ensure_initialized()
        names: list[str] = []
        for _spec, base_url in configured_local_servers():
            try:
                names.extend(self._backend.list_models(base_url=base_url, api_key=self._api_key))
            except NotImplementedError:
                continue
            except ProviderError:
                raise
            except Exception as exc:
                raise ProviderError(
                    f"Listing models failed: {exc}", provider=self._backend.provider_name
                ) from exc
        return names

    def list_chat_models(self, provider: str) -> list[str]:
        """List frontier chat models known to the backend for *provider*.

        Initializes the backend first so ``cfg.json_mode`` suppression is
        applied before the SDK import inside the backend runs. Returns
        ``[]`` when the backend does not implement the query.

        Raises:
            ProviderError: on any other backend failure.
        """
        self._ensure_initialized()
        try:
            return self._backend.list_chat_models(provider)
        except NotImplementedError:
            return []
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Listing chat models failed: {exc}", provider=self._backend.provider_name
            ) from exc

    def pull_model(self, model: str, *, on_progress: Callable[..., Any] | None = None) -> None:
        """Pull *model* via the backend.

        The local-server base URL for the model's provider is passed along;
        it is ``""`` for hosted providers.

        Raises:
            ProviderError: if the backend cannot pull (including when it
                does not implement pulling) or the pull fails.
        """
        self._ensure_initialized()
        try:
            base_url = _api_base_for(parse_model_ref(model)) or ""
            self._backend.pull_model(model, base_url=base_url, on_progress=on_progress)
        except NotImplementedError as exc:
            raise ProviderError(
                f"Cannot pull model {model!r}: backend does not support pulling",
                provider=self._backend.provider_name,
            ) from exc
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Cannot pull model {model!r}: {exc}", provider=self._backend.provider_name
            ) from exc

    def show_model(self, model: str) -> dict[str, Any] | None:
        """Return model metadata from the backend.

        Returns ``None`` when the backend does not implement the query;
        otherwise whatever the backend returns (which may itself be ``None``).

        Raises:
            ProviderError: on any other backend failure.
        """
        self._ensure_initialized()
        try:
            base_url = _api_base_for(parse_model_ref(model)) or ""
            return self._backend.show_model(model, base_url=base_url)
        except NotImplementedError:
            return None
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Showing model {model!r} failed: {exc}", provider=self._backend.provider_name
            ) from exc

    def get_capabilities(self, model: str) -> list[str]:
        """Return the ``capabilities`` list from ``show_model``, or ``[]``.

        ``[]`` is returned when metadata is unavailable or the field is
        missing or not a list. Errors from ``show_model`` propagate.
        """
        info = self.show_model(model)
        if info is None:
            return []
        caps = info.get("capabilities", [])
        return caps if isinstance(caps, list) else []

    def rerank(self, query: str, candidates: list[str]) -> list[float]:
        """Score *candidates* against *query* using ``cfg.reranker_model``.

        An empty *candidates* list returns ``[]`` without contacting the
        backend.

        Raises:
            ProviderError: if the backend call fails.
        """
        if not candidates:
            return []
        self._ensure_initialized()
        ref = parse_model_ref(cfg.reranker_model)
        request = RerankRequest(
            ref=ref,
            query=query,
            candidates=candidates,
            api_base=_api_base_for(ref),
            api_key=self._api_key or None,
        )
        try:
            result = self._backend.rerank(request)
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Rerank failed: {exc}", provider=self._backend.provider_name
            ) from exc
        return result.scores

    def supports_rerank(self) -> bool:
        """Report rerank support; delegates to ``backend.available()``."""
        return self._backend.available()

    def available(self) -> bool:
        """Return True when the configured SDK backend can service catalog calls."""
        return self._backend.available()

    def shutdown(self) -> None:
        """No-op: SDK-backed providers hold no lilbee-side resources."""

    def invalidate_load_cache(self, model_path: Path | None = None) -> None:
        """No-op: cloud backends have no local model cache to evict."""