Coverage for src/lilbee/retrieval/embedder.py: 100%

101 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-28 01:01 +0000

1"""Thin wrapper around LLM provider embeddings API.""" 

2 

3import logging 

4import threading 

5 

6import numpy as np 

7 

8from lilbee.core.config import Config 

9from lilbee.data.chunk import CHARS_PER_TOKEN 

10from lilbee.providers.base import LLMProvider 

11from lilbee.providers.model_ref import ProviderModelRef, parse_model_ref 

12from lilbee.runtime.progress import DetailedProgressCallback, EmbedEvent, EventType, noop_callback 

13 

log: logging.Logger = logging.getLogger(__name__)

# Soft cap on the summed character length of the texts sent to the provider in
# a single embed call. It is a soft cap because a text that alone exceeds it is
# still sent, as a batch of one.
MAX_BATCH_CHARS: int = 6000

17 

18 

def _name_base(ref: ProviderModelRef) -> str:
    """Normalize *ref*'s name for loose comparison against provider listings.

    Drops everything from the first ``:`` onward (e.g. a ``:latest`` tag),
    lowercases the remainder, and turns spaces into hyphens.
    """
    untagged, _sep, _tag = ref.name.partition(":")
    return untagged.lower().replace(" ", "-")

21 

22 

def _remote_sees_model(ref: ProviderModelRef, provider: LLMProvider) -> bool:
    """Return True if *provider* lists a model matching *ref*'s normalized base name.

    Matching is a substring test after lowercasing and replacing spaces with
    hyphens on both sides, so a base of ``nomic-embed-text`` matches a listed
    ``nomic-embed-text:latest``.

    An empty base name is rejected up front. ``"" in s`` is true for every
    string, so it would otherwise report a match against any non-empty model
    list. A failing ``list_models`` call is a best-effort miss: it is logged at
    debug level and reported as False.
    """
    base = _name_base(ref)
    if not base:
        # Nothing meaningful to look for; also skips the provider round-trip.
        return False
    try:
        available = provider.list_models()
    except Exception:
        log.debug("provider list_models failed during availability check", exc_info=True)
        return False
    return any(base in m.lower().replace(" ", "-") for m in available)

31 

32 

def _native_has_model(model: str) -> bool:
    """Return True if the llama.cpp provider's ``resolve_model_path`` accepts *model*.

    ``resolve_model_path`` is imported lazily inside the function. Presumably
    this avoids loading the native backend at module import; TODO confirm.
    Any exception raised while resolving is treated as "not available". It is
    logged at debug level, like the provider-list probe, rather than silently
    discarded.
    """
    from lilbee.providers.llama_cpp.provider import resolve_model_path

    try:
        resolve_model_path(model)
    except Exception:
        log.debug("native model resolution failed for %r", model, exc_info=True)
        return False
    else:
        return True

41 

42 

def is_model_available(model: str, provider: LLMProvider) -> bool:
    """Report whether *model* can be resolved by *provider* or the native registry.

    The provider's own model listing is consulted first. When that misses,
    refs flagged as remote (``ollama/`` and API providers) are reported
    unavailable without probing the native registry, since they resolve through
    the SDK backend at call time. Only local refs fall through to the native
    lookup. A falsy *model* is always unavailable.
    """
    if not model:
        return False
    ref = parse_model_ref(model)
    if _remote_sees_model(ref, provider):
        return True
    # Short-circuit: remote refs never reach the native probe.
    return not ref.is_remote and _native_has_model(model)

57 

58 

class Embedder:
    """Embedding wrapper: truncates, batches, validates vectors, and counts truncations.

    Texts longer than ``embed_char_budget`` are cut before being sent to the
    provider. ``truncated_total`` accumulates truncations across all calls
    (guarded by a lock). ``last_batch_truncated`` holds the count produced by
    the most recent successful ``embed_batch`` call.
    """

    def __init__(self, config: Config, provider: LLMProvider) -> None:
        self._config = config
        self._provider = provider
        # Truncations performed by the most recent successful embed_batch() call.
        self.last_batch_truncated = 0
        # Running truncation count across all calls; only touched under the lock.
        self._truncated_total = 0
        self._truncated_lock = threading.Lock()

    @property
    def embed_char_budget(self) -> int:
        """Effective char limit, never below the chunker's max chunk size.

        ``max_embed_chars`` guards the embed model's context. Clamping it up to
        ``chunk_size * CHARS_PER_TOKEN`` stops a finished full-budget chunk from
        silently losing its tail to a limit set below what the chunker emits.
        """
        return max(self._config.max_embed_chars, self._config.chunk_size * CHARS_PER_TOKEN)

    @property
    def truncated_total(self) -> int:
        """Cumulative count of chunks truncated since process start (thread-safe read)."""
        with self._truncated_lock:
            return self._truncated_total

    def truncate(self, text: str) -> str:
        """Truncate text to the embed char budget, counting any truncation.

        Returns *text* unchanged when it fits. Otherwise returns its first
        ``embed_char_budget`` characters and increments ``truncated_total``.
        """
        budget = self.embed_char_budget
        if len(text) <= budget:
            return text
        log.debug("Truncating chunk from %d to %d chars for embedding", len(text), budget)
        with self._truncated_lock:
            self._truncated_total += 1
        return text[:budget]

    def validate_vector(self, vector: list[float]) -> None:
        """Validate embedding vector dimension and values.

        Raises:
            ValueError: if the length differs from ``config.embedding_dim`` or
                any component is NaN or infinite (the first offender is reported).
        """
        if len(vector) != self._config.embedding_dim:
            raise ValueError(
                f"Embedding dimension mismatch: expected {self._config.embedding_dim}, "
                f"got {len(vector)}"
            )
        arr = np.asarray(vector, dtype=np.float64)
        bad = np.where(~np.isfinite(arr))[0]
        if bad.size:
            i = int(bad[0])
            raise ValueError(f"Embedding contains invalid value at index {i}: {vector[i]}")

    def validate_model(self) -> bool:
        """Check if the configured embedding model is available. No side effects."""
        return self.embedding_available()

    def embedding_available(self) -> bool:
        """Return True if the embedding model can be resolved.

        Checks the provider model list and the native registry path
        resolution. Returns True if either finds the model.
        """
        return is_model_available(self._config.embedding_model, self._provider)

    @staticmethod
    def _check_vector_count(got: int, expected: int) -> None:
        """Raise ValueError unless the provider returned one vector per input."""
        if got != expected:
            raise ValueError(f"Provider returned {got} embeddings for {expected} inputs")

    def embed(self, text: str) -> list[float]:
        """Embed a single text string, return vector.

        Raises:
            ValueError: if the provider does not return exactly one vector, or
                the vector fails ``validate_vector``.
        """
        vectors = self._provider.embed([self.truncate(text)])
        self._check_vector_count(len(vectors), 1)
        result: list[float] = vectors[0]
        self.validate_vector(result)
        return result

    def _embed_and_report(
        self,
        batch: list[str],
        vectors: list[list[float]],
        source: str,
        total_chunks: int,
        on_progress: DetailedProgressCallback,
    ) -> None:
        """Embed one batch, append its vectors to *vectors*, and fire one ``embed`` event."""
        # list() lets len() work even if a provider yields lazily.
        batch_vectors = list(self._provider.embed(batch))
        # Checked per batch so a short/long response can't misalign vectors with
        # their chunks or skew the progress counts reported below.
        self._check_vector_count(len(batch_vectors), len(batch))
        vectors.extend(batch_vectors)
        on_progress(
            EventType.EMBED,
            EmbedEvent(file=source, chunk=len(vectors), total_chunks=total_chunks),
        )

    def embed_batch(
        self,
        texts: list[str],
        *,
        source: str = "",
        on_progress: DetailedProgressCallback = noop_callback,
    ) -> list[list[float]]:
        """Embed multiple texts with adaptive batching, return list of vectors.

        Texts are truncated and then grouped into batches whose summed length
        stays within ``MAX_BATCH_CHARS``. A text that alone exceeds the cap is
        sent as a batch of one. Fires ``embed`` progress events per batch when
        *on_progress* is provided.

        On success, ``last_batch_truncated`` is set to the number of *texts*
        this call truncated.

        Raises:
            ValueError: if the provider returns a different number of vectors
                than texts in a batch, or any vector fails ``validate_vector``.
        """
        if not texts:
            self.last_batch_truncated = 0
            return []
        total_chunks = len(texts)
        # Counted locally: diffing the shared truncated_total would also pick up
        # truncations made concurrently by other threads using this embedder.
        truncated_here = 0
        vectors: list[list[float]] = []
        batch: list[str] = []
        batch_chars = 0
        for text in texts:
            truncated = self.truncate(text)
            # truncate() only ever shortens, so a length change marks a truncation.
            if len(truncated) != len(text):
                truncated_here += 1
            chunk_len = len(truncated)
            if batch and batch_chars + chunk_len > MAX_BATCH_CHARS:
                self._embed_and_report(batch, vectors, source, total_chunks, on_progress)
                batch = []
                batch_chars = 0
            batch.append(truncated)
            batch_chars += chunk_len
        if batch:
            self._embed_and_report(batch, vectors, source, total_chunks, on_progress)
        for vec in vectors:
            self.validate_vector(vec)
        self.last_batch_truncated = truncated_here
        return vectors