# src/lilbee/providers/sdk_llm_provider.py -- coverage: 100% (149 statements, coverage.py v7.13.4)

1"""SDK-agnostic LLM provider implementing the public ``LLMProvider`` Protocol. 

2 

3``SdkLLMProvider`` owns the semantic layer: auth key injection, option 

4translation, model-ref parsing, error wrapping, and lazy one-shot 

5backend initialization (``configure_logging`` + ``inject_provider_keys`` 

6on first use). It speaks to the underlying SDK exclusively through an 

7``LlmSdkBackend``, so swapping SDKs is a one-file adapter change. 

8 

9Zero direct SDK imports live here. The adapter owns SDK-specific 

10concerns like wire-format prefixes (``ollama/``) and OpenAI content-parts 

11schema for image inputs. 

12""" 

13 
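
# Typical wiring, as a sketch (``make_backend()`` is a placeholder -- any
# ``LlmSdkBackend`` implementation works here):
#
#     provider = SdkLLMProvider(make_backend(), base_url="http://localhost:11434")
#     reply = provider.chat([{"role": "user", "content": "hello"}])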

from __future__ import annotations

import logging
import os
from collections.abc import Callable
from pathlib import Path
from typing import Any, Literal, overload

from lilbee.core.config import cfg
from lilbee.providers.base import ClosableIterator, LLMProvider, ProviderError
from lilbee.providers.model_ref import parse_model_ref, translate_options
from lilbee.providers.sdk_backend import (
    PROVIDER_KEYS,
    CompletionRequest,
    EmbeddingRequest,
    LlmSdkBackend,
    RerankRequest,
)
from lilbee.providers.worker.transport import OcrBackend
from lilbee.vision import PageText

log = logging.getLogger(__name__)


def inject_provider_keys() -> None:
    """Copy per-provider API keys from config into ``os.environ``.

    OpenAI-compatible SDKs read provider-specific env vars
    (``OPENAI_API_KEY``, ``ANTHROPIC_API_KEY``, ...) at call time. This
    bridges lilbee's config system to that convention. Explicit env
    vars are never overwritten so users can still override via their
    shell.
    """
    for _, cfg_field, env_var, _ in PROVIDER_KEYS:
        value = getattr(cfg, cfg_field, "")
        if value and not os.environ.get(env_var):
            os.environ[env_var] = value
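
# Shape note: the loop above unpacks each ``PROVIDER_KEYS`` entry as a
# 4-tuple whose middle fields are the config attribute and env-var name.
# An entry presumably resembles ``(..., "openai_api_key", "OPENAI_API_KEY", ...)``;
# the field values shown here are illustrative guesses -- the authoritative
# definition lives in ``lilbee.providers.sdk_backend``.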


class SdkLLMProvider(LLMProvider):
    """Provider that delegates SDK calls to an ``LlmSdkBackend``."""

    def __init__(
        self,
        backend: LlmSdkBackend,
        *,
        base_url: str = "http://localhost:11434",
        api_key: str = "",
    ) -> None:
        self._backend = backend
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key
        self._initialized = False

    def _ensure_initialized(self) -> None:
        """Apply one-shot backend setup before the first call.

        Runs ``configure_logging(suppress_debug=cfg.json_mode)`` and
        ``inject_provider_keys()`` exactly once, regardless of whether
        the first operation is ``chat``, ``embed``, or a catalog query.
        Both steps happen together because the backend's first SDK
        import must see (a) the debug flag applied, and (b) per-provider
        API keys in ``os.environ``.
        """
        if self._initialized:
            return
        try:
            self._backend.configure_logging(suppress_debug=cfg.json_mode)
        except (ImportError, AttributeError):
            log.debug("backend.configure_logging failed", exc_info=True)
        inject_provider_keys()
        self._initialized = True

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed texts via the configured backend."""
        self._ensure_initialized()
        ref = parse_model_ref(cfg.embedding_model)
        request = EmbeddingRequest(
            ref=ref,
            inputs=texts,
            api_base=self._base_url if ref.needs_api_base else None,
            api_key=self._api_key or None,
        )
        try:
            result = self._backend.embed(request)
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Embedding failed: {exc}", provider=self._backend.provider_name
            ) from exc
        return result.vectors

    @overload
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: Literal[False] = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str: ...

    @overload
    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: Literal[True],
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> ClosableIterator[str]: ...

    def chat(
        self,
        messages: list[dict[str, str]],
        *,
        stream: bool = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str | ClosableIterator[str]:
        """Chat completion via the configured backend."""
        self._ensure_initialized()
        ref = parse_model_ref(model or cfg.chat_model)
        translated = translate_options(options, ref) if options else {}
        request = CompletionRequest(
            ref=ref,
            messages=list(messages),
            options=translated,
            api_base=self._base_url if ref.needs_api_base else None,
            api_key=self._api_key or None,
        )
        if stream:
            return self._chat_stream(request)
        try:
            result = self._backend.complete(request)
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Chat failed: {exc}", provider=self._backend.provider_name
            ) from exc
        return result.content
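
    # Usage sketch: the overloads above let type checkers narrow the return
    # type on the ``stream`` flag, e.g.:
    #
    #     text = provider.chat(msgs)                 # -> str
    #     tokens = provider.chat(msgs, stream=True)  # -> ClosableIterator[str]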

    def _chat_stream(self, request: CompletionRequest) -> ClosableIterator[str]:
        """Yield content tokens from a streaming completion.

        Exceptions surfaced by the backend at either call time or during
        iteration are re-raised as ``ProviderError`` so callers always
        see a consistent error type.
        """
        try:
            stream = self._backend.complete_stream(request)
            for chunk in stream:
                if chunk.content:
                    yield chunk.content
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Chat failed: {exc}", provider=self._backend.provider_name
            ) from exc
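
    # Note: ``_chat_stream`` is a generator function, and generator objects
    # expose ``close()`` alongside the iterator methods -- presumably how the
    # returned object satisfies ``ClosableIterator`` (assuming it is a
    # structural Protocol in ``lilbee.providers.base``).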

    def vision_ocr(
        self,
        png_bytes: bytes,
        model: str,
        prompt: str = "",
        *,
        timeout: float | None = None,
    ) -> str:
        """OCR via a multipart chat completion; ``timeout`` enforced via thread pool."""
        from lilbee.vision import OCR_PROMPT, build_vision_messages

        messages = build_vision_messages(prompt or OCR_PROMPT, png_bytes)
        if timeout and timeout > 0:
            from concurrent.futures import ThreadPoolExecutor

            with ThreadPoolExecutor(max_workers=1) as pool:
                future = pool.submit(self.chat, messages, stream=False, model=model)
                result = future.result(timeout=timeout)
        else:
            result = self.chat(messages, stream=False, model=model)
        if not isinstance(result, str):
            raise ProviderError(
                f"Vision OCR returned non-text response ({type(result).__name__}).",
                provider=self._backend.provider_name,
            )
        return result
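
    # Caveat: ``future.result(timeout=...)`` raises
    # ``concurrent.futures.TimeoutError`` once the deadline passes, but the
    # worker thread keeps running, and ``ThreadPoolExecutor.__exit__`` calls
    # ``shutdown(wait=True)`` -- so a timed-out OCR call still blocks until
    # the underlying request finishes.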

    def pdf_ocr(
        self,
        path: Path,
        *,
        backend: OcrBackend,
        model: str = "",
        per_page_timeout_s: float | None = None,
        quiet: bool = True,
        on_progress: Callable[..., None] | None = None,
    ) -> list[PageText]:
        """SDK backend cannot rasterise PDFs locally; ingest callers fall back."""
        del path, backend, model, per_page_timeout_s, quiet, on_progress
        raise NotImplementedError(
            "Hosted models do not support scanned-PDF OCR. "
            "Set LILBEE_VISION_MODEL to a local GGUF vision model to enable it."
        )

    def list_models(self) -> list[str]:
        """List models from the backend (empty list when listing is unsupported)."""
        try:
            return self._backend.list_models(base_url=self._base_url, api_key=self._api_key)
        except NotImplementedError:
            return []
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Listing models failed: {exc}", provider=self._backend.provider_name
            ) from exc

    def list_chat_models(self, provider: str) -> list[str]:
        """List frontier chat models known to the backend for *provider*.

        Initializes the backend first so ``cfg.json_mode`` suppression is
        applied before the SDK import inside the backend runs.
        """
        self._ensure_initialized()
        try:
            return self._backend.list_chat_models(provider)
        except NotImplementedError:
            return []
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Listing chat models failed: {exc}", provider=self._backend.provider_name
            ) from exc

    def pull_model(self, model: str, *, on_progress: Callable[..., Any] | None = None) -> None:
        """Pull a model via the backend."""
        try:
            self._backend.pull_model(model, base_url=self._base_url, on_progress=on_progress)
        except NotImplementedError as exc:
            raise ProviderError(
                f"Cannot pull model {model!r}: backend does not support pulling",
                provider=self._backend.provider_name,
            ) from exc
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Cannot pull model {model!r}: {exc}", provider=self._backend.provider_name
            ) from exc

    def show_model(self, model: str) -> dict[str, Any] | None:
        """Return model metadata, or None when unsupported or not found."""
        try:
            return self._backend.show_model(model, base_url=self._base_url)
        except NotImplementedError:
            return None
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Showing model {model!r} failed: {exc}", provider=self._backend.provider_name
            ) from exc

    def get_capabilities(self, model: str) -> list[str]:
        """Return capability tags from ``show_model`` output, or ``[]``."""
        info = self.show_model(model)
        if info is None:
            return []
        caps = info.get("capabilities", [])
        return caps if isinstance(caps, list) else []
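
    # Illustrative shape (assumed, not a documented schema): metadata like
    # ``{"capabilities": ["vision", "tools"], ...}`` yields
    # ``["vision", "tools"]``; a missing or non-list field degrades to ``[]``.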

    def rerank(self, query: str, candidates: list[str]) -> list[float]:
        """Rerank candidates via the SDK backend using ``cfg.reranker_model``."""
        if not candidates:
            return []
        self._ensure_initialized()
        ref = parse_model_ref(cfg.reranker_model)
        request = RerankRequest(
            ref=ref,
            query=query,
            candidates=candidates,
            api_base=self._base_url if ref.needs_api_base else None,
            api_key=self._api_key or None,
        )
        try:
            result = self._backend.rerank(request)
        except ProviderError:
            raise
        except Exception as exc:
            raise ProviderError(
                f"Rerank failed: {exc}", provider=self._backend.provider_name
            ) from exc
        return result.scores

    def supports_rerank(self) -> bool:
        """SDK-backed rerank is available when the underlying SDK is importable."""
        return self._backend.available()

    def available(self) -> bool:
        """Return True when the configured SDK backend can service catalog calls."""
        return self._backend.available()

    def shutdown(self) -> None:
        """SDK-backed providers hold no lilbee-side resources."""

    def invalidate_load_cache(self, model_path: Path | None = None) -> None:
        """No-op: cloud backends have no local model cache to evict."""