Coverage for src / lilbee / retrieval / embedder.py: 100%
101 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-28 01:01 +0000
1"""Thin wrapper around LLM provider embeddings API."""
3import logging
4import threading
6import numpy as np
8from lilbee.core.config import Config
9from lilbee.data.chunk import CHARS_PER_TOKEN
10from lilbee.providers.base import LLMProvider
11from lilbee.providers.model_ref import ProviderModelRef, parse_model_ref
12from lilbee.runtime.progress import DetailedProgressCallback, EmbedEvent, EventType, noop_callback
log: logging.Logger = logging.getLogger(__name__)

# Soft cap on the combined character length of the texts sent in a single
# provider.embed() call by Embedder.embed_batch; a lone text longer than this
# is still sent, on its own.
MAX_BATCH_CHARS = 6_000
def _name_base(ref: ProviderModelRef) -> str:
    """Return *ref*'s name normalized for loose comparison.

    Everything from the first ``:`` onward (the tag) is dropped, the rest is
    lowercased and spaces become hyphens, e.g. ``"Nomic Embed:latest"`` ->
    ``"nomic-embed"``.
    """
    untagged, _sep, _tag = ref.name.partition(":")
    return untagged.lower().replace(" ", "-")
def _remote_sees_model(ref: ProviderModelRef, provider: LLMProvider) -> bool:
    """Return True if *provider* lists a model matching *ref*'s base name.

    Matching is a loose substring test on normalized names (see
    ``_name_base``): the ``:tag`` of *ref* is stripped first, so any listed
    variant whose normalized name contains the base counts as a match.
    A failure while listing models is logged at debug level and treated as
    "not available" rather than propagated.
    """
    base = _name_base(ref)
    if not base:
        # An empty base is a substring of every name and would match any
        # listed model, so treat it as unresolvable instead.
        return False
    try:
        available = provider.list_models()
    except Exception:
        # Best-effort probe: availability checks must not crash the caller.
        log.debug("provider list_models failed during availability check", exc_info=True)
        return False
    return any(base in m.lower().replace(" ", "-") for m in available)
def _native_has_model(model: str) -> bool:
    """Report whether the native registry's ``resolve_model_path`` accepts *model*.

    Any exception raised during resolution is interpreted as "not present".
    """
    # Deferred import of the llama_cpp provider module; presumably kept local
    # so the native backend is only loaded when this probe runs — confirm.
    from lilbee.providers.llama_cpp.provider import resolve_model_path as _resolve

    try:
        _resolve(model)
        found = True
    except Exception:
        found = False
    return found
def is_model_available(model: str, provider: LLMProvider) -> bool:
    """Return True if *model* resolves via *provider* or the native registry.

    The provider's model listing is consulted first for every ref. If that
    misses, remote-prefixed refs (``ollama/`` and API providers) report False
    without probing the native registry, since they resolve through the SDK
    backend at call time; all other refs fall back to the native lookup.
    An empty *model* string is never available.
    """
    if not model:
        return False
    parsed = parse_model_ref(model)
    if _remote_sees_model(parsed, provider):
        return True
    # Only local refs can be found on disk by the native registry.
    return not parsed.is_remote and _native_has_model(model)
class Embedder:
    """Embedding wrapper: truncates, batches, validates vectors, and counts truncations.

    Two truncation counters are kept: ``truncated_total`` is cumulative across
    every ``truncate`` call on this instance and is guarded by a lock, while
    ``last_batch_truncated`` holds the count for the most recent
    ``embed_batch`` call only.
    """

    def __init__(self, config: Config, provider: LLMProvider) -> None:
        self._config = config
        self._provider = provider
        # Number of inputs truncated by the most recent embed_batch() call.
        self.last_batch_truncated = 0
        # Cumulative truncation count; read/written only under _truncated_lock.
        self._truncated_total = 0
        self._truncated_lock = threading.Lock()

    @property
    def embed_char_budget(self) -> int:
        """Effective char limit, never below the chunker's max chunk size.

        ``max_embed_chars`` guards the embed model's context; clamping it up to
        ``chunk_size * CHARS_PER_TOKEN`` stops a finished full-budget chunk from
        silently losing its tail to a limit set below what the chunker emits.
        """
        return max(self._config.max_embed_chars, self._config.chunk_size * CHARS_PER_TOKEN)

    @property
    def truncated_total(self) -> int:
        """Cumulative count of chunks truncated since this Embedder was created (thread-safe read)."""
        with self._truncated_lock:
            return self._truncated_total

    def truncate(self, text: str) -> str:
        """Truncate text to the embed char budget, counting any truncation.

        Returns *text* unchanged when it fits; otherwise returns its first
        ``embed_char_budget`` characters and increments ``truncated_total``.
        """
        budget = self.embed_char_budget
        if len(text) <= budget:
            return text
        log.debug("Truncating chunk from %d to %d chars for embedding", len(text), budget)
        with self._truncated_lock:
            self._truncated_total += 1
        return text[:budget]

    def validate_vector(self, vector: list[float]) -> None:
        """Validate embedding vector dimension and values.

        Raises:
            ValueError: if ``len(vector)`` differs from ``config.embedding_dim``,
                or if any element is NaN or infinite (the first offending
                index is reported).
        """
        if len(vector) != self._config.embedding_dim:
            raise ValueError(
                f"Embedding dimension mismatch: expected {self._config.embedding_dim}, "
                f"got {len(vector)}"
            )
        arr = np.asarray(vector, dtype=np.float64)
        bad = np.where(~np.isfinite(arr))[0]
        if bad.size:
            i = int(bad[0])
            raise ValueError(f"Embedding contains invalid value at index {i}: {vector[i]}")

    def validate_model(self) -> bool:
        """Check if the configured embedding model is available. No side effects."""
        return self.embedding_available()

    def embedding_available(self) -> bool:
        """Return True if the embedding model can be resolved.

        Checks the provider model list and the native registry path
        resolution. Returns True if either finds the model.
        """
        return is_model_available(self._config.embedding_model, self._provider)

    def embed(self, text: str) -> list[float]:
        """Embed a single text string (after truncation), return its validated vector.

        Raises:
            ValueError: if the provider returns no vectors, or the returned
                vector fails ``validate_vector``.
        """
        vectors = self._provider.embed([self.truncate(text)])
        if len(vectors) == 0:
            # Surface a clear error instead of an opaque IndexError below.
            raise ValueError("Embedding provider returned no vectors for 1 input")
        result: list[float] = vectors[0]
        self.validate_vector(result)
        return result

    def embed_batch(
        self,
        texts: list[str],
        *,
        source: str = "",
        on_progress: DetailedProgressCallback = noop_callback,
    ) -> list[list[float]]:
        """Embed multiple texts with adaptive batching, return list of vectors.

        Each text is truncated to ``embed_char_budget``, then texts are grouped
        so that one provider call carries at most ``MAX_BATCH_CHARS``
        characters (a single longer text is sent on its own). Fires an
        ``embed`` progress event after each provider call, reporting
        *source* and the number of vectors received so far. On success,
        ``last_batch_truncated`` is set to the number of texts this call
        truncated. Vectors are returned in the same order as *texts*.

        Raises:
            ValueError: if the provider returns a different number of vectors
                than texts for any batch, or any vector fails
                ``validate_vector``.
        """
        if not texts:
            self.last_batch_truncated = 0
            return []
        total_chunks = len(texts)
        vectors: list[list[float]] = []
        batch: list[str] = []
        batch_chars = 0
        # Count truncations locally: diffing the shared truncated_total would
        # also pick up truncations made concurrently by other threads.
        truncated_here = 0
        for text in texts:
            clipped = self.truncate(text)
            if len(clipped) < len(text):
                truncated_here += 1
            chunk_len = len(clipped)
            # Flush before this text would push the batch past the cap; an
            # oversized text therefore still goes out alone.
            if batch and batch_chars + chunk_len > MAX_BATCH_CHARS:
                self._send_batch(batch, vectors, source, total_chunks, on_progress)
                batch = []
                batch_chars = 0
            batch.append(clipped)
            batch_chars += chunk_len
        if batch:
            self._send_batch(batch, vectors, source, total_chunks, on_progress)
        for vec in vectors:
            self.validate_vector(vec)
        self.last_batch_truncated = truncated_here
        return vectors

    def _send_batch(
        self,
        batch: list[str],
        vectors: list[list[float]],
        source: str,
        total_chunks: int,
        on_progress: DetailedProgressCallback,
    ) -> None:
        """Embed one batch, append its vectors to *vectors*, then report progress.

        Raises:
            ValueError: if the provider returns a vector count different from
                ``len(batch)``, which would misalign vectors and texts.
        """
        returned = self._provider.embed(batch)
        if len(returned) != len(batch):
            raise ValueError(
                f"Embedding provider returned {len(returned)} vectors "
                f"for {len(batch)} inputs"
            )
        vectors.extend(returned)
        on_progress(
            EventType.EMBED,
            EmbedEvent(file=source, chunk=len(vectors), total_chunks=total_chunks),
        )
167 return vectors