Coverage for src/lilbee/providers/base.py: 100% (23 statements)


1"""Base protocol and exceptions for LLM providers.""" 

2 

3from __future__ import annotations 

4 

5from collections.abc import Callable, Iterator 

6from pathlib import Path 

7from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, overload, runtime_checkable 

8 

9from pydantic import BaseModel 

10 

11if TYPE_CHECKING: 

12 from lilbee.providers.worker.transport import OcrBackend 

13 from lilbee.vision import PageText 

14 

15T_co = TypeVar("T_co", covariant=True) 

16 

17 

18@runtime_checkable 

19class ClosableIterator(Iterator[T_co], Protocol[T_co]): 

20 """An iterator that releases resources when ``close()`` is called. 

21 

22 Streaming chat responses use this to guarantee the upstream model lock 

23 is released even when callers truncate the stream before exhaustion. 

24 Generators satisfy this implicitly; explicit wrappers (e.g. the llama-cpp 

25 chat-lock iterator) implement it directly. 

26 """ 

27 

28 def close(self) -> None: ... 

29 

30 
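

# Illustrative sketch, not part of the module: because the protocol is
# runtime-checkable and structural, a plain generator already satisfies it,
# and truncating early still releases resources via ``close()``. The names
# below are hypothetical.
def _closable_iterator_demo() -> None:
    def _tokens() -> Iterator[str]:
        yield "partial"
        yield "response"

    stream = _tokens()
    assert isinstance(stream, ClosableIterator)  # structural check passes
    next(stream)
    stream.close()  # stop early; the generator's close() frees resources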


class LLMOptions(BaseModel):
    """Validated options passed to LLM providers.

    Only these fields are forwarded; everything else is rejected
    to prevent injection of sensitive parameters like api_base or api_key.
    """


    temperature: float | None = None
    top_p: float | None = None
    top_k: int | None = None
    seed: int | None = None
    num_predict: int | None = None
    repeat_penalty: float | None = None
    num_ctx: int | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return only non-None values as a dict."""
        return {k: v for k, v in self.model_dump().items() if v is not None}


def filter_options(options: dict[str, Any]) -> dict[str, Any]:
    """Validate and filter generation options through the LLMOptions model."""
    return LLMOptions(**options).to_dict()
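

# Illustrative sketch, not part of the module: ``None``-valued fields are
# dropped by ``to_dict``, so only explicitly set generation knobs reach the
# backend, and keys that are not declared LLMOptions fields never appear in
# the output under pydantic's default model config. The helper name is
# hypothetical.
def _filter_options_demo() -> None:
    assert filter_options({"temperature": 0.2, "seed": None}) == {"temperature": 0.2}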



class ProviderError(Exception):
    """Raised when an LLM provider operation fails."""

    def __init__(self, message: str, *, provider: str = "") -> None:
        self.provider = provider
        super().__init__(message)
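

# Illustrative sketch, not part of the module: the ``provider`` attribute tags
# an error with the backend that raised it so callers can log the origin.
# The backend name and helper are hypothetical.
def _provider_error_demo() -> None:
    try:
        raise ProviderError("model not loaded", provider="llama-cpp")
    except ProviderError as exc:
        assert exc.provider == "llama-cpp"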



ChatMessage = dict[str, str]



class LLMProvider(Protocol):
    """Protocol for pluggable LLM backends."""

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of texts, return list of vectors."""
        ...

    @overload
    def chat(
        self,
        messages: list[ChatMessage],
        *,
        stream: Literal[False] = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str: ...

    @overload
    def chat(
        self,
        messages: list[ChatMessage],
        *,
        stream: Literal[True],
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> ClosableIterator[str]: ...

    def chat(
        self,
        messages: list[ChatMessage],
        *,
        stream: bool = False,
        options: dict[str, Any] | None = None,
        model: str | None = None,
    ) -> str | ClosableIterator[str]:
        """Chat completion. Returns str for non-stream, ClosableIterator[str] for stream."""
        ...
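
    # Illustrative usage sketch, not part of the protocol: the overloads give
    # callers a precise return type when branching on ``stream``, and wrapping
    # the streamed iterator in ``contextlib.closing`` releases the model lock
    # even when iteration stops early::
    #
    #     from contextlib import closing
    #
    #     reply = provider.chat(messages)                   # -> str
    #     with closing(provider.chat(messages, stream=True)) as stream:
    #         first = next(stream, "")                      # -> str chunks
    #     # close() has run here, releasing the upstream lock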


    def vision_ocr(
        self,
        png_bytes: bytes,
        model: str,
        prompt: str = "",
        *,
        timeout: float | None = None,
    ) -> str:
        """OCR one page image; ``timeout`` seconds, ``None``/``0`` = no cap."""
        ...

    def pdf_ocr(
        self,
        path: Path,
        *,
        backend: OcrBackend,
        model: str = "",
        per_page_timeout_s: float | None = None,
        quiet: bool = True,
        on_progress: Callable[..., None] | None = None,
    ) -> list[PageText]:
        """OCR every page of a PDF, returning per-page text in input order.

        Backends that cannot OCR scanned PDFs locally raise
        :class:`NotImplementedError`; ingest callers catch and log this.
        """
        ...
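
    # Illustrative ingest-side sketch, not part of the protocol, matching the
    # contract above (``log`` is a hypothetical logger)::
    #
    #     try:
    #         pages = provider.pdf_ocr(path, backend=backend)
    #     except NotImplementedError:
    #         log.info("backend cannot OCR scanned PDFs locally; skipping")
    #         pages = []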


    def list_models(self) -> list[str]:
        """List available model identifiers."""
        ...

    def list_chat_models(self, provider: str) -> list[str]:
        """List the frontier chat models this backend is aware of for *provider*.

        Returns the unfiltered upstream catalog: whatever litellm exposes
        for API providers, or an empty list for backends like native
        llama-cpp that have no notion of an external catalog.
        """
        ...


    def pull_model(self, model: str, *, on_progress: Callable[..., Any] | None = None) -> None:
        """Download a model. Raises NotImplementedError if not supported."""
        ...

    def show_model(self, model: str) -> dict[str, Any] | None:
        """Return model metadata, or None if the backend doesn't expose it."""
        ...

    def get_capabilities(self, model: str) -> list[str]:
        """Return capability tags (e.g. ``["completion", "vision"]``) for *model*.

        Returns an empty list when the backend does not support capability
        reporting or the model is not found.
        """
        ...
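
    # Illustrative sketch, not part of the protocol: combine list_models()
    # with get_capabilities() to filter a backend's catalog, e.g. down to
    # vision-capable models::
    #
    #     vision_models = [
    #         m for m in provider.list_models()
    #         if "vision" in provider.get_capabilities(m)
    #     ]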


    def rerank(self, query: str, candidates: list[str]) -> list[float]:
        """Score *candidates* for their relevance to *query*, one float per candidate.

        The backend resolves the reranker model from ``cfg.reranker_model``.
        Callers MUST check that ``cfg.reranker_model`` is non-empty before
        calling; use :meth:`supports_rerank` for UI-render decisions.

        Returns a list of floats in input order, higher = more relevant;
        empty ``candidates`` returns ``[]``. Raises :class:`ProviderError`
        when the backend does not support reranking or ``cfg.reranker_model``
        is empty.
        """
        ...


    def supports_rerank(self) -> bool:
        """Capability probe: can this backend rerank *if* a model is configured?

        This is a pure capability check, NOT "a reranker is currently
        active": it returns ``True`` even when ``cfg.reranker_model`` is
        empty, so the settings UI keeps the picker visible. Callers that
        need to know whether reranking is actually configured must check
        ``bool(cfg.reranker_model)`` separately; ``rerank()`` is the gated
        path that requires a non-empty value.
        """
        return False
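
    # Illustrative gating sketch, not part of the protocol; ``cfg`` is the
    # hypothetical config object the docstrings refer to::
    #
    #     if provider.supports_rerank() and cfg.reranker_model:
    #         scores = provider.rerank(query, candidates)
    #     else:
    #         scores = None  # feature unsupported or not yet configured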


    def shutdown(self) -> None:
        """Release resources (e.g. background threads). No-op if nothing to clean up."""
        ...

    def invalidate_load_cache(self, model_path: Path | None = None) -> None:
        """Drop loaded-model state; ``None`` evicts all, else only that path. No-op default."""
        return

    def warm_up_pool(self) -> None:
        """Eagerly register configured roles so :meth:`WorkerPool.start_eager` has work to do.

        Default no-op so providers without a worker pool (SDK / routing
        wrappers) can be passed to ``Services`` unchanged. Implemented by
        :class:`LlamaCppProvider` to register chat / embed / rerank / vision
        roles whose model is configured.
        """
        return