Coverage for src / lilbee / providers / base.py: 100%

34 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-28 01:01 +0000

1"""Base protocol and exceptions for LLM providers.""" 

2 

3from __future__ import annotations 

4 

5from collections.abc import Callable, Iterator 

6from enum import StrEnum 

7from pathlib import Path 

8from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, overload, runtime_checkable 

9 

10from pydantic import BaseModel 

11 

12if TYPE_CHECKING: 

13 from lilbee.providers.worker.transport import OcrBackend 

14 from lilbee.vision import PageText 

15 

16T_co = TypeVar("T_co", covariant=True) 

17 

18 

@runtime_checkable
class ClosableIterator(Iterator[T_co], Protocol[T_co]):
    """Iterator protocol that additionally exposes ``close()`` for cleanup.

    Streaming chat responses are typed with this protocol so that the
    upstream model lock is guaranteed to be released even if the caller
    stops consuming the stream before it is exhausted. Plain generators
    satisfy the protocol implicitly; explicit wrappers (for example the
    llama-cpp chat-lock iterator) implement it directly.
    """

    def close(self) -> None:
        """Release the resources held by this iterator."""
        ...

30 

31 

class LLMOptions(BaseModel):
    """Whitelist of generation options that may be passed to an LLM provider.

    Only the fields declared here are forwarded to a backend; the intent
    is to keep sensitive parameters such as ``api_base`` or ``api_key``
    from being injected through the options dict.

    NOTE(review): no ``model_config`` is set on this class, so how unknown
    keys are handled follows pydantic's default -- confirm that matches
    the "rejected" wording used by callers.
    """

    temperature: float | None = None
    top_p: float | None = None
    top_k: int | None = None
    seed: int | None = None
    num_predict: int | None = None
    repeat_penalty: float | None = None
    num_ctx: int | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return the options that were actually set (``None`` values dropped)."""
        dumped = self.model_dump()
        return {name: value for name, value in dumped.items() if value is not None}

49 

50 

def filter_options(options: dict[str, Any]) -> dict[str, Any]:
    """Run *options* through :class:`LLMOptions` and return the set fields.

    The dict is expanded into keyword arguments for ``LLMOptions`` and the
    validated model is serialised back with ``to_dict()``, so entries
    whose value is ``None`` are omitted from the result.
    """
    validated = LLMOptions(**options)
    return validated.to_dict()

54 

55 

56class ProviderErrorKind(StrEnum): 

57 """Provider-agnostic category of a failed provider call. 

58 

59 Classified by exception type at each backend boundary so callers can 

60 branch on the kind instead of matching message strings (which are 

61 provider-specific and drift between SDK versions). 

62 """ 

63 

64 AUTH = "auth" 

65 RATE_LIMIT = "rate_limit" 

66 CONTEXT_OVERFLOW = "context_overflow" 

67 NOT_FOUND = "not_found" 

68 BAD_REQUEST = "bad_request" 

69 CONNECTION = "connection" 

70 SERVER = "server" 

71 UNKNOWN = "unknown" 

72 

73 

class ProviderError(Exception):
    """Exception signalling that an LLM provider operation failed.

    Attributes:
        provider: Value of the ``provider`` keyword argument; empty string
            when the raiser did not supply one.
        kind: Provider-agnostic failure category. Backends that cannot
            classify a failure leave it as ``ProviderErrorKind.UNKNOWN``.
    """

    def __init__(
        self,
        message: str,
        *,
        provider: str = "",
        kind: ProviderErrorKind = ProviderErrorKind.UNKNOWN,
    ) -> None:
        """Store *message* on the base exception and record provider/kind."""
        super().__init__(message)
        self.provider = provider
        self.kind = kind

91 

92 

93ChatMessage = dict[str, str] 

94 

95 

class LLMProvider(Protocol):
    """Structural interface implemented by every pluggable LLM backend.

    Members whose body is ``...`` carry no default behaviour. The members
    with a real body (``supports_rerank``, ``invalidate_load_cache`` and
    ``warm_up_pool``) act as defaults for classes that explicitly
    subclass this protocol.
    """

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of *texts* and return one vector per text."""
        ...

    @overload
    def chat(
        self,
        messages: list[ChatMessage],
        *,
        stream: Literal[False] = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str: ...

    @overload
    def chat(
        self,
        messages: list[ChatMessage],
        *,
        stream: Literal[True],
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> ClosableIterator[str]: ...

    def chat(
        self,
        messages: list[ChatMessage],
        *,
        stream: bool = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str | ClosableIterator[str]:
        """Run a chat completion over *messages*.

        Returns a ``str`` when ``stream`` is false and a
        ``ClosableIterator[str]`` when ``stream`` is true.
        """
        ...

    def vision_ocr(
        self,
        png_bytes: bytes,
        model: str,
        prompt: str = "",
        *,
        timeout: float | None = None,
    ) -> str:
        """OCR a single page image.

        ``timeout`` is expressed in seconds; ``None`` or ``0`` means no cap.
        """
        ...

    def pdf_ocr(
        self,
        path: Path,
        *,
        backend: OcrBackend,
        model: str = "",
        per_page_timeout_s: float | None = None,
        quiet: bool = True,
        on_progress: Callable[..., None] | None = None,
    ) -> list[PageText]:
        """OCR each page of the PDF at *path*; results are in input order.

        Backends that cannot OCR scanned PDFs locally raise
        :class:`NotImplementedError`; ingest callers catch and log this.
        """
        ...

    def list_models(self) -> list[str]:
        """Return the identifiers of the available models."""
        ...

    def list_chat_models(self, provider: str) -> list[str]:
        """Return the frontier chat models known for *provider*.

        The result is the unfiltered upstream catalog: whatever litellm
        exposes for API providers, and an empty list for backends such as
        native llama-cpp that have no notion of external catalogs.
        """
        ...

    def pull_model(self, model: str, *, on_progress: Callable[..., Any] | None = None) -> None:
        """Download *model*; raises NotImplementedError if not supported."""
        ...

    def show_model(self, model: str) -> dict[str, Any] | None:
        """Return metadata for *model*, or None if the backend exposes none."""
        ...

    def get_capabilities(self, model: str) -> list[str]:
        """Return the capability tags of *model*, e.g. ``["completion", "vision"]``.

        An empty list is returned when the backend does not support
        capability reporting or the model is not found.
        """
        ...

    def rerank(self, query: str, candidates: list[str]) -> list[float]:
        """Score each of *candidates* for relevance to *query*.

        The reranker model is resolved by the backend from
        ``cfg.reranker_model``. Callers MUST verify that
        ``cfg.reranker_model`` is non-empty before calling; for UI-render
        decisions use :meth:`supports_rerank` instead.

        Returns: one float per candidate, in input order, where higher
        means more relevant. Empty ``candidates`` returns ``[]``.
        Raises :class:`ProviderError` when the backend does not support
        reranking or ``cfg.reranker_model`` is empty.
        """
        ...

    def supports_rerank(self) -> bool:
        """Report whether this backend could rerank *if* a model were configured.

        This is a pure capability check and does NOT mean "a reranker is
        currently active". With an empty ``cfg.reranker_model`` the probe
        returns ``True`` so the settings UI keeps the picker visible.
        Callers that need to know whether reranking is actually configured
        must check ``bool(cfg.reranker_model)`` separately; ``rerank()``
        is the gated path that requires a non-empty value.
        """
        return False

    def shutdown(self) -> None:
        """Release resources such as background threads; no-op if there are none."""
        ...

    def invalidate_load_cache(self, model_path: Path | None = None) -> None:
        """Drop loaded-model state: ``None`` evicts all, else only that path. No-op default."""
        return None

    def warm_up_pool(self) -> None:
        """Eagerly register configured roles so :meth:`WorkerPool.start_eager` has work to do.

        The default is a no-op, which lets providers without a worker pool
        (SDK / routing wrappers) be passed to ``Services`` unchanged.
        :class:`LlamaCppProvider` implements it to register the chat /
        embed / rerank / vision roles whose model is configured.
        """
        return None

224 

225 def warm_up_pool(self) -> None: 

226 """Eagerly register configured roles so :meth:`WorkerPool.start_eager` has work to do. 

227 

228 Default no-op so providers without a worker pool (SDK / routing 

229 wrappers) can be passed to ``Services`` unchanged. Implemented by 

230 :class:`LlamaCppProvider` to register chat / embed / rerank / vision 

231 roles whose model is configured. 

232 """ 

233 return