Coverage for src / lilbee / providers / sdk_backend.py: 100%

59 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-15 20:55 +0000

1"""Protocol and value types for SDK-backed LLM backends. 

2 

3A backend hides one third-party SDK. The ``SdkLLMProvider`` speaks to 

4backends exclusively through the ``LlmSdkBackend`` Protocol and the 

5value types defined here, so SDK response objects never leak outside 

6the adapter. 

7 

8This module is intentionally dependency-free (no SDK imports, no 

9lilbee provider imports beyond the shared base types). 

10""" 

11 

12from __future__ import annotations 

13 

14from collections.abc import Callable, Iterator 

15from dataclasses import dataclass, field 

16from typing import TYPE_CHECKING, Any, Protocol 

17 

18# Display name for the active backend the SDK is talking to. The 

19# adapter's own identity is exposed separately via provider_name. 

20from lilbee.providers.backend_names import BackendName 

21 

22if TYPE_CHECKING: 

23 # circular: sdk_backend -> model_ref -> types -> sdk_backend (annotation-only) 

24 from lilbee.providers.model_ref import ProviderModelRef 

25 

# Single source of truth for per-provider API key configuration.
# Each entry is (provider_name, config_field, env_var, display_label).
# Backend-agnostic: OpenAI-compatible SDKs all read these env vars at
# call time. Tuple order is the canonical display order downstream
# consumers (TUI grouping, catalog sections) honor when surfacing
# providers.
PROVIDER_KEYS: tuple[tuple[str, str, str, str], ...] = (
    ("openrouter", "openrouter_api_key", "OPENROUTER_API_KEY", "OpenRouter"),
    ("gemini", "gemini_api_key", "GEMINI_API_KEY", "Gemini"),
    ("anthropic", "anthropic_api_key", "ANTHROPIC_API_KEY", "Anthropic"),
    ("openai", "openai_api_key", "OPENAI_API_KEY", "OpenAI"),
    ("mistral", "mistral_api_key", "MISTRAL_API_KEY", "Mistral"),
    ("deepseek", "deepseek_api_key", "DEEPSEEK_API_KEY", "DeepSeek"),
)

# Derived set of config field names (used to check which config updates
# touch API keys). Named `cfg_field` to avoid colliding visually with
# the `dataclasses.field` import.
API_KEY_FIELDS: frozenset[str] = frozenset(
    cfg_field for _provider, cfg_field, _env, _label in PROVIDER_KEYS
)

# Provider name -> cfg attribute holding that provider's API key.
PROVIDER_API_KEY_FIELD: dict[str, str] = {
    provider: cfg_field for provider, cfg_field, _env, _label in PROVIDER_KEYS
}

45 

46 

def get_provider_api_key(provider: str) -> str | None:
    """Return the configured API key for *provider*, or ``None`` if unknown / unset.

    *provider* is the lowercase routing key from a parsed model ref (e.g.
    ``"openai"``). Returns ``None`` for unknown providers AND for known
    providers whose key is unconfigured; callers can distinguish via
    :data:`PROVIDER_API_KEY_FIELD`.
    """
    # Resolve the cfg attribute name first so unknown providers return
    # early without triggering the config import at all. `field_name`
    # (not `field`) avoids shadowing the dataclasses.field import.
    field_name = PROVIDER_API_KEY_FIELD.get(provider.lower())
    if field_name is None:
        return None

    # Imported lazily: the module docstring promises this module stays
    # dependency-free at import time, so cfg is pulled in per call.
    from lilbee.core.config import cfg

    # Normalize empty-string (unconfigured) keys to None so callers see
    # a single falsy shape.
    value = getattr(cfg, field_name)
    return value or None

62 

63 

# Ordered substring -> display-name table consumed by
# detect_backend_name(). Matching is first-hit-wins over the lowercased
# URL, so more specific patterns must come before broader ones (e.g.
# "localhost:11434" ahead of "ollama"); reorder with care.
_BACKEND_URL_PATTERNS: tuple[tuple[str, BackendName], ...] = (
    # Default Ollama port — matches even when the URL never says "ollama".
    ("localhost:11434", BackendName.OLLAMA),
    ("ollama", BackendName.OLLAMA),
    ("openrouter", BackendName.OPENROUTER),
    ("openai", BackendName.OPENAI),
    ("anthropic", BackendName.ANTHROPIC),
    # Gemini is served from googleapis.com as well as gemini hostnames.
    ("googleapis", BackendName.GEMINI),
    ("gemini", BackendName.GEMINI),
    ("mistral", BackendName.MISTRAL),
    ("deepseek", BackendName.DEEPSEEK),
)

75 

76 

def detect_backend_name(base_url: str) -> BackendName:
    """Return the display name of the backend behind ``base_url``.

    Adapter-agnostic; any SDK implementation can delegate to this helper.
    Scans ``_BACKEND_URL_PATTERNS`` in order (first match wins) against
    the lowercased URL and falls back to ``BackendName.REMOTE`` when no
    known pattern matches.
    """
    haystack = base_url.lower()
    return next(
        (name for pattern, name in _BACKEND_URL_PATTERNS if pattern in haystack),
        BackendName.REMOTE,
    )

89 

90 

@dataclass(frozen=True)
class CompletionResult:
    """Single-shot chat completion result returned by a backend.

    Frozen value type: instances are immutable and safe to share.
    """

    # Full completion text produced by the backend.
    content: str
    # Backend-reported stop reason; None when the SDK exposes none.
    finish_reason: str | None = None
    # Model identifier the backend reports it served, when available.
    model: str | None = None

98 

99 

@dataclass(frozen=True)
class StreamChunk:
    """One delta yielded during a streaming chat completion."""

    # Incremental content for this delta (may be empty on the final chunk).
    content: str
    # Stop reason, typically set only on the terminating chunk; None otherwise.
    finish_reason: str | None = None

106 

107 

@dataclass(frozen=True)
class EmbeddingResult:
    """Embedding vectors returned by a backend for a batch of inputs."""

    # One vector per input, in input order (per the embed() contract).
    vectors: list[list[float]]
    # Model identifier the backend reports it used, when available.
    model: str | None = None

114 

115 

@dataclass(frozen=True)
class CompletionRequest:
    """Backend-agnostic request for a single completion call.

    ``ref`` carries the parsed model reference; the adapter converts it
    to the wire format its SDK expects. ``messages`` is the raw lilbee
    message list (may contain ``images`` bytes); the adapter formats it
    for its SDK. ``api_base`` is populated for local/Ollama deployments
    and omitted for API-hosted models.
    """

    # Parsed model reference; adapters translate it to their SDK's format.
    ref: ProviderModelRef
    # Raw lilbee message dicts; may include an ``images`` key with bytes.
    messages: list[dict[str, Any]]
    # Adapter-interpreted options passed through to the SDK; default_factory
    # avoids sharing one mutable dict across instances.
    options: dict[str, Any] = field(default_factory=dict)
    # Base URL for local/Ollama deployments; None for API-hosted models.
    api_base: str | None = None
    # API key for this request; None presumably means "resolve from
    # config/env" — confirm against the adapter implementations.
    api_key: str | None = None

132 

133 

@dataclass(frozen=True)
class EmbeddingRequest:
    """Backend-agnostic request for an embedding call."""

    # Parsed model reference; adapters translate it to their SDK's format.
    ref: ProviderModelRef
    # Texts to embed; the backend returns one vector per entry.
    inputs: list[str]
    # Base URL for local deployments; None for API-hosted models.
    api_base: str | None = None
    # API key for this request; None presumably defers to config/env.
    api_key: str | None = None

142 

143 

@dataclass(frozen=True)
class RerankRequest:
    """Backend-agnostic rerank request."""

    # Parsed model reference; adapters translate it to their SDK's format.
    ref: ProviderModelRef
    # Query the candidates are scored against.
    query: str
    # Candidate texts; the backend returns one score per entry, in order.
    candidates: list[str]
    # Base URL for local deployments; None for API-hosted models.
    api_base: str | None = None
    # API key for this request; None presumably defers to config/env.
    api_key: str | None = None

153 

154 

@dataclass(frozen=True)
class RerankResult:
    """Rerank scores returned by a backend, one per candidate in input order."""

    # Relevance scores aligned with the request's candidate order.
    scores: list[float]
    # Model identifier the backend reports it used, when available.
    model: str | None = None

161 

162 

class LlmSdkBackend(Protocol):
    """Protocol every LLM SDK adapter must satisfy.

    Structural (``typing.Protocol``): adapters satisfy it by shape and
    need not inherit from this class. The provider calls these methods
    through the Protocol only; SDK response objects never cross the
    seam. Methods with a natural "not supported" signal are documented
    below.

    Lifecycle: ``available()`` is the cheap install check called before
    any other method; ``configure_logging`` runs once at first use.
    ``complete`` / ``complete_stream`` / ``embed`` are the hot-path
    operations. ``list_models`` / ``list_chat_models`` / ``pull_model``
    / ``show_model`` are catalog helpers and may raise
    ``NotImplementedError`` or return empty values when unsupported.

    Error contract: implementations must raise only ``ProviderError`` or
    ``NotImplementedError`` from any method. ``SdkLLMProvider`` wraps any
    other exception at the seam; adapters should translate SDK-specific
    errors (httpx errors, third-party SDK exceptions) into
    ``ProviderError`` so the provider can pass them through.
    """

    @property
    def provider_name(self) -> str:
        """Stable identifier used when wrapping errors in ``ProviderError``."""
        ...

    def active_backend_name(self, base_url: str) -> str:
        """Return the display name of the backend the adapter is talking to.

        ``"Ollama"`` for an Ollama URL, ``"OpenAI"`` for an OpenAI URL,
        etc.; unknown URLs fall back to ``"Remote"``. The adapter's own
        identity is exposed separately through ``provider_name``.
        (``detect_backend_name`` in this module implements this mapping.)
        """
        ...

    def available(self) -> bool:
        """Return True when the underlying SDK is importable."""
        ...

    def configure_logging(self, *, suppress_debug: bool) -> None:
        """Apply backend-level logging toggles (best-effort no-op if unsupported)."""
        ...

    def complete(self, request: CompletionRequest) -> CompletionResult:
        """Run a single-shot chat completion."""
        ...

    def complete_stream(self, request: CompletionRequest) -> Iterator[StreamChunk]:
        """Run a streaming chat completion, yielding content chunks."""
        ...

    def embed(self, request: EmbeddingRequest) -> EmbeddingResult:
        """Embed a batch of inputs, returning one vector per input."""
        ...

    def rerank(self, request: RerankRequest) -> RerankResult:
        """Score *candidates* against *query*, returning one float per candidate.

        Raise ``NotImplementedError`` if the backend has no rerank API.
        An empty ``request.candidates`` returns ``RerankResult([])``
        without an SDK call.
        """
        ...

    def list_models(self, *, base_url: str, api_key: str) -> list[str]:
        """List model identifiers visible to the backend. Return [] if unsupported."""
        ...

    def list_chat_models(self, provider: str) -> list[str]:
        """List chat-mode models from the SDK's catalog for *provider*.

        Returns the unfiltered upstream catalog. Backends without a
        notion of frontier providers return ``[]``.

        Unlike ``list_models``, this is a static pricing/capability table,
        not a runtime HTTP probe.
        """
        ...

    def pull_model(
        self,
        model: str,
        *,
        base_url: str,
        on_progress: Callable[..., Any] | None = None,
    ) -> None:
        """Pull a model. Raise NotImplementedError if unsupported.

        ``on_progress``, when given, is invoked with adapter-specific
        progress arguments — hence the permissive ``Callable[..., Any]``.
        """
        ...

    def show_model(self, model: str, *, base_url: str) -> dict[str, Any] | None:
        """Return model metadata dict or None if unsupported / not found."""
        ...