Coverage for src / lilbee / providers / sdk_backend.py: 100%

70 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-28 01:01 +0000

1"""Protocol and value types for SDK-backed LLM backends. 

2 

A backend hides one third-party SDK. The ``SdkLLMProvider`` speaks to
backends exclusively through the ``LlmSdkBackend`` Protocol and the
value types defined here, so SDK response objects never leak outside
the adapter.

This module is intentionally dependency-free (no SDK imports, no
lilbee provider imports beyond the shared base types).

10""" 

11 

12from __future__ import annotations 

13 

14import os 

15from collections.abc import Callable, Iterator 

16from dataclasses import dataclass, field 

17from typing import TYPE_CHECKING, Any, Protocol 

18 

19# Display name for the active backend the SDK is talking to. The 

20# adapter's own identity is exposed separately via provider_name. 

21from lilbee.providers.backend_names import BackendName 

22from lilbee.providers.local_servers import detect_local_server 

23 

24if TYPE_CHECKING: 

25 # circular: sdk_backend -> model_ref -> types -> sdk_backend (annotation-only) 

26 from lilbee.providers.model_ref import ProviderModelRef 

27 

# Single source of truth for per-provider API key configuration.
# Each row is (provider_name, config_field, env_var, display_label).
# Backend-agnostic: OpenAI-compatible SDKs all read these env vars at call
# time. Tuple order is the canonical display order downstream consumers
# (TUI grouping, catalog sections) honor when surfacing providers.
PROVIDER_KEYS: tuple[tuple[str, str, str, str], ...] = (
    ("openrouter", "openrouter_api_key", "OPENROUTER_API_KEY", "OpenRouter"),
    ("gemini", "gemini_api_key", "GEMINI_API_KEY", "Gemini"),
    ("anthropic", "anthropic_api_key", "ANTHROPIC_API_KEY", "Anthropic"),
    ("openai", "openai_api_key", "OPENAI_API_KEY", "OpenAI"),
    ("mistral", "mistral_api_key", "MISTRAL_API_KEY", "Mistral"),
    ("deepseek", "deepseek_api_key", "DEEPSEEK_API_KEY", "DeepSeek"),
)

# Derived set of config field names (for checking which updates touch API keys).
API_KEY_FIELDS: frozenset[str] = frozenset(row[1] for row in PROVIDER_KEYS)

# Provider name -> cfg attribute holding that provider's API key.
# The loop variable is named ``cfg_field`` (not ``field``) so it does not
# shadow the ``dataclasses.field`` helper imported at the top of the module.
PROVIDER_API_KEY_FIELD: dict[str, str] = {
    prov: cfg_field for prov, cfg_field, *_ in PROVIDER_KEYS
}

# Provider name -> the SDK's own env var (read at call time by the backend).
PROVIDER_API_KEY_ENV: dict[str, str] = {
    prov: env_var for prov, _cfg_field, env_var, *_ in PROVIDER_KEYS
}

51 

52 

def get_provider_api_key(provider: str) -> str | None:
    """Return the configured API key for *provider*, or ``None`` if unknown / unset.

    *provider* is the routing key from a parsed model ref (e.g.
    ``"openai"``); it is lowercased before lookup. Returns ``None`` for
    unknown providers AND for known providers whose key is unconfigured;
    callers can distinguish the two via :data:`PROVIDER_API_KEY_FIELD`.
    Reads only the lilbee config field; use :func:`provider_has_key` to
    also honor the SDK's own env var.
    """
    # Imported inside the function rather than at module import time.
    from lilbee.core.config import cfg

    # Named ``attr_name`` so the local does not shadow ``dataclasses.field``.
    attr_name = PROVIDER_API_KEY_FIELD.get(provider.lower())
    if attr_name is None:
        return None
    # Any falsy config value (e.g. an empty string) is reported as None.
    return getattr(cfg, attr_name) or None

69 

70 

def provider_has_key(provider: str) -> bool:
    """True if *provider* has a key via its standard env var or the lilbee config field.

    The environment variable is checked first; an env var that is set but
    empty counts as absent. Providers with no known env var fall straight
    through to the config lookup in :func:`get_provider_api_key`.
    """
    env_var = PROVIDER_API_KEY_ENV.get(provider.lower())
    if env_var and os.environ.get(env_var):
        return True
    return get_provider_api_key(provider) is not None

77 

78 

# Hosted API providers identified by URL substring. Local OpenAI-compatible
# servers (Ollama, LM Studio) are matched ahead of this table via the
# local-servers registry, so they are not listed here. Rows are scanned in
# order by ``detect_backend_name`` and the first substring hit wins.
_REMOTE_API_URL_PATTERNS: tuple[tuple[str, BackendName], ...] = (
    ("openrouter", BackendName.OPENROUTER),
    ("openai", BackendName.OPENAI),
    ("anthropic", BackendName.ANTHROPIC),
    ("googleapis", BackendName.GEMINI),
    ("gemini", BackendName.GEMINI),
    ("mistral", BackendName.MISTRAL),
    ("deepseek", BackendName.DEEPSEEK),
)

91 

92 

def detect_backend_name(base_url: str) -> BackendName:
    """Return the display name of the backend behind ``base_url``.

    Adapter-agnostic; any SDK implementation can delegate to this helper.
    Checks the local-server registry (Ollama, LM Studio) first, then the
    hosted-API URL patterns, and falls back to ``BackendName.REMOTE``.
    """
    # Local servers take precedence over the hosted-API substring table.
    local = detect_local_server(base_url)
    if local is not None:
        return local.display_name
    # Patterns are lowercase, so compare against a lowercased URL.
    url_lower = base_url.lower()
    for pattern, name in _REMOTE_API_URL_PATTERNS:
        if pattern in url_lower:
            return name
    return BackendName.REMOTE

108 

109 

@dataclass(frozen=True)
class CompletionResult:
    """Single-shot chat completion result returned by a backend."""

    # Text of the completion.
    content: str
    # Optional; defaults to None when not supplied.
    finish_reason: str | None = None
    # Optional; defaults to None when not supplied.
    model: str | None = None

117 

118 

@dataclass(frozen=True)
class StreamChunk:
    """One delta yielded during a streaming chat completion."""

    # Text carried by this delta.
    content: str
    # Optional; defaults to None when not supplied.
    finish_reason: str | None = None

125 

126 

@dataclass(frozen=True)
class EmbeddingResult:
    """Embedding vectors returned by a backend for a batch of inputs."""

    # One list of floats per embedded input.
    vectors: list[list[float]]
    # Optional; defaults to None when not supplied.
    model: str | None = None

133 

134 

@dataclass(frozen=True)
class CompletionRequest:
    """Backend-agnostic request for a single completion call.

    ``ref`` carries the parsed model reference; the adapter converts it
    to the wire format its SDK expects. ``messages`` is the raw lilbee
    message list (may contain ``images`` bytes); the adapter formats it
    for its SDK. ``api_base`` is populated for local/Ollama deployments
    and omitted for API-hosted models.
    """

    ref: ProviderModelRef
    messages: list[dict[str, Any]]
    # default_factory gives every request its own dict instead of a shared one.
    options: dict[str, Any] = field(default_factory=dict)
    api_base: str | None = None
    api_key: str | None = None

151 

152 

@dataclass(frozen=True)
class EmbeddingRequest:
    """Backend-agnostic request for an embedding call."""

    # Parsed model reference.
    ref: ProviderModelRef
    # Batch of strings to embed.
    inputs: list[str]
    # Optional; defaults to None when not supplied.
    api_base: str | None = None
    # Optional; defaults to None when not supplied.
    api_key: str | None = None

161 

162 

@dataclass(frozen=True)
class RerankRequest:
    """Backend-agnostic rerank request."""

    # Parsed model reference.
    ref: ProviderModelRef
    # Query the candidates are scored against.
    query: str
    # Texts to score.
    candidates: list[str]
    # Optional; defaults to None when not supplied.
    api_base: str | None = None
    # Optional; defaults to None when not supplied.
    api_key: str | None = None

172 

173 

@dataclass(frozen=True)
class RerankResult:
    """Rerank scores returned by a backend, one per candidate in input order."""

    # One score per candidate, in input order.
    scores: list[float]
    # Optional; defaults to None when not supplied.
    model: str | None = None

180 

181 

class LlmSdkBackend(Protocol):
    """Protocol every LLM SDK adapter must satisfy.

    The provider calls these methods through the Protocol only; SDK
    response objects never cross the seam. Methods with a natural
    "not supported" signal are documented below.

    Lifecycle: ``available()`` is the cheap install check called before
    any other method; ``configure_logging`` runs once at first use.
    ``complete`` / ``complete_stream`` / ``embed`` are the hot-path
    operations. ``list_models`` / ``list_chat_models`` / ``pull_model``
    / ``show_model`` are catalog helpers and may raise
    ``NotImplementedError`` or return empty values when unsupported.

    Error contract: implementations must raise only ``ProviderError`` or
    ``NotImplementedError`` from any method. ``SdkLLMProvider`` wraps any
    other exception at the seam; adapters should translate SDK-specific
    errors (httpx errors, third-party SDK exceptions) into
    ``ProviderError`` so the provider can pass them through.
    """

    @property
    def provider_name(self) -> str:
        """Stable identifier used when wrapping errors in ``ProviderError``."""
        ...

    def active_backend_name(self, base_url: str) -> str:
        """Return the display name of the backend the adapter is talking to.

        ``"Ollama"`` for an Ollama URL, ``"OpenAI"`` for an OpenAI URL,
        etc.; unknown URLs fall back to ``"Remote"``. The adapter's own
        identity is exposed separately through ``provider_name``.
        """
        ...

    def available(self) -> bool:
        """Return True when the underlying SDK is importable."""
        ...

    def configure_logging(self, *, suppress_debug: bool) -> None:
        """Apply backend-level logging toggles (best-effort no-op if unsupported)."""
        ...

    def complete(self, request: CompletionRequest) -> CompletionResult:
        """Run a single-shot chat completion."""
        ...

    def complete_stream(self, request: CompletionRequest) -> Iterator[StreamChunk]:
        """Run a streaming chat completion, yielding content chunks."""
        ...

    def embed(self, request: EmbeddingRequest) -> EmbeddingResult:
        """Embed a batch of inputs, returning one vector per input."""
        ...

    def rerank(self, request: RerankRequest) -> RerankResult:
        """Score *candidates* against *query*, returning one float per candidate.

        Raise ``NotImplementedError`` if the backend has no rerank API.
        An empty ``request.candidates`` returns ``RerankResult([])``
        without an SDK call.
        """
        ...

    def list_models(self, *, base_url: str, api_key: str) -> list[str]:
        """List model identifiers visible to the backend. Return [] if unsupported."""
        ...

    def list_chat_models(self, provider: str) -> list[str]:
        """List chat-mode models from the SDK's catalog for *provider*.

        Returns the unfiltered upstream catalog. Backends without a
        notion of frontier providers return ``[]``.

        Unlike ``list_models``, this is a static pricing/capability table,
        not a runtime HTTP probe.
        """
        ...

    def pull_model(
        self,
        model: str,
        *,
        base_url: str,
        on_progress: Callable[..., Any] | None = None,
    ) -> None:
        """Pull a model. Raise NotImplementedError if unsupported."""
        ...

    def show_model(self, model: str, *, base_url: str) -> dict[str, Any] | None:
        """Return model metadata dict or None if unsupported / not found."""
        ...